
Commit 89dbc55

Fix fabric examples and load_checkpoint hparams ref (#21013)
* fix examples
* fix reference to hparams
* use kwargs

1 parent: 11968ba

4 files changed: +4 −4 lines

docs/source-pytorch/common/checkpointing_basic.rst (1 addition, 1 deletion)

@@ -111,7 +111,7 @@ The LightningModule also has access to the Hyperparameters
 .. code-block:: python
 
     model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
-    print(model.learning_rate)
+    print(model.hparams.learning_rate)
 
 
 ----
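Context for the fix: `load_from_checkpoint` re-instantiates the module from the hyperparameters stored in the checkpoint, and those are exposed under `model.hparams` only when the module saved them. A minimal sketch, assuming the docs' `MyLightningModule` calls `save_hyperparameters()` in `__init__` (the default value and the `Linear` layer here are illustrative, not from the docs):

import torch
from lightning.pytorch import LightningModule


class MyLightningModule(LightningModule):
    def __init__(self, learning_rate: float = 1e-3):
        super().__init__()
        # Records init args under self.hparams and writes them into every
        # checkpoint, which is what makes `model.hparams.learning_rate`
        # available after `load_from_checkpoint`.
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(32, 2)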

examples/fabric/image_classifier/train_fabric.py (1 addition, 1 deletion)

@@ -158,7 +158,7 @@ def run(hparams):
     # When using distributed training, use `fabric.save`
     # to ensure the current process is allowed to save a checkpoint
     if hparams.save_model:
-        fabric.save(model.state_dict(), "mnist_cnn.pt")
+        fabric.save(path="mnist_cnn.pt", state=model.state_dict())
 
 
 if __name__ == "__main__":
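The fix switches to the keyword form of `Fabric.save(path, state)`, where `state` is a dict of objects to persist; the old positional order passed the state dict where `path` is expected. A minimal single-process round-trip, assuming CPU and a throwaway `Linear` model (both illustrative):

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)
model = torch.nn.Linear(4, 2)

# `state` must be a dict; a raw state_dict qualifies, so every tensor in it
# is written to the checkpoint file.
fabric.save(path="mnist_cnn.pt", state=model.state_dict())

# Called without `state`, `fabric.load` returns the full checkpoint dict,
# which feeds straight back into `load_state_dict`.
model.load_state_dict(fabric.load(path="mnist_cnn.pt"))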

examples/fabric/kfold_cv/train_fabric.py (1 addition, 1 deletion)

@@ -161,7 +161,7 @@ def run(hparams):
     # When using distributed training, use `fabric.save`
     # to ensure the current process is allowed to save a checkpoint
     if hparams.save_model:
-        fabric.save(model.state_dict(), "mnist_cnn.pt")
+        fabric.save(path="mnist_cnn.pt", state=model.state_dict())
 
 
 if __name__ == "__main__":

examples/fabric/tensor_parallel/train.py (1 addition, 1 deletion)

@@ -67,7 +67,7 @@ def train():
     # See `fabric consolidate --help` if you need to convert the checkpoint to a single file
     fabric.print("Saving a (distributed) checkpoint ...")
     state = {"model": model, "optimizer": optimizer, "iteration": i}
-    fabric.save("checkpoint.pt", state)
+    fabric.save(path="checkpoint.pt", state=state)
 
     fabric.print("Training successfully completed!")
     fabric.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
