
Commit 06635b4

use kwargs
1 parent 299cb1c
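
In short: the three example scripts now call `fabric.save` with explicit keyword arguments instead of positional ones. The before/after pattern, taken directly from the diffs below:

    # Before: positional arguments
    fabric.save("mnist_cnn.pt", model.state_dict())
    # After this commit: explicit keyword arguments
    fabric.save(path="mnist_cnn.pt", state=model.state_dict())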


3 files changed: +3 additions, -3 deletions


examples/fabric/image_classifier/train_fabric.py

Lines changed: 1 addition & 1 deletion

@@ -158,7 +158,7 @@ def run(hparams):
     # When using distributed training, use `fabric.save`
     # to ensure the current process is allowed to save a checkpoint
     if hparams.save_model:
-        fabric.save("mnist_cnn.pt", model.state_dict())
+        fabric.save(path="mnist_cnn.pt", state=model.state_dict())
 
 
 if __name__ == "__main__":

examples/fabric/kfold_cv/train_fabric.py

Lines changed: 1 addition & 1 deletion

@@ -161,7 +161,7 @@ def run(hparams):
     # When using distributed training, use `fabric.save`
     # to ensure the current process is allowed to save a checkpoint
     if hparams.save_model:
-        fabric.save("mnist_cnn.pt", model.state_dict())
+        fabric.save(path="mnist_cnn.pt", state=model.state_dict())
 
 
 if __name__ == "__main__":

examples/fabric/tensor_parallel/train.py

Lines changed: 1 addition & 1 deletion

@@ -67,7 +67,7 @@ def train():
     # See `fabric consolidate --help` if you need to convert the checkpoint to a single file
     fabric.print("Saving a (distributed) checkpoint ...")
     state = {"model": model, "optimizer": optimizer, "iteration": i}
-    fabric.save("checkpoint.pt", state)
+    fabric.save(path="checkpoint.pt", state=state)
 
     fabric.print("Training successfully completed!")
     fabric.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
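
For context, a minimal self-contained sketch of the keyword-argument form end to end, including the matching `fabric.load` call. The model, filename, and accelerator choice are illustrative assumptions, not part of this commit:

    # Minimal sketch (illustrative): save and restore state with Lightning Fabric,
    # passing `path` and `state` as keyword arguments.
    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cpu")  # accelerator is an assumption for this sketch
    fabric.launch()

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)

    # `fabric.save` is safe under distributed training: only the process(es)
    # permitted to write will actually save the checkpoint.
    state = {"model": model, "optimizer": optimizer, "iteration": 0}
    fabric.save(path="checkpoint.pt", state=state)

    # `fabric.load` restores the objects referenced in `state` in place.
    fabric.load(path="checkpoint.pt", state=state)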
