Skip to content

Commit a5e85a9

Browse files
authored
Fix training artifacts for 2GB+ models and MSELoss (#22414)
1 parent 6407d81 commit a5e85a9

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

orttraining/orttraining/python/training/onnxblock/blocks.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,10 @@ def __call__(self, *args, **kwargs):
5454
output = self.build(*args, **kwargs)
5555

5656
if accessor._GLOBAL_ACCESSOR.has_path:
57+
# `save` will destructively access any external data
58+
copied_model = copy.deepcopy(accessor._GLOBAL_ACCESSOR.model)
5759
onnx.save(
58-
accessor._GLOBAL_ACCESSOR.model,
60+
copied_model,
5961
self.temp_onnx_file_path,
6062
save_as_external_data=True,
6163
all_tensors_to_one_file=True,

orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,7 +1159,8 @@ def test_generate_artifacts_external_data_one_file():
11591159
assert os.path.exists(os.path.join(temp_dir, "checkpoint"))
11601160

11611161

1162-
def test_generate_artifacts_external_data_separate_files():
1162+
@pytest.mark.parametrize("loss", [loss_t for loss_t in artifacts.LossType])
1163+
def test_generate_artifacts_external_data_separate_files(loss):
11631164
with tempfile.TemporaryDirectory() as temp_dir:
11641165
_, simple_net = _get_models("cpu", 32, 28, 10, 10)
11651166

@@ -1176,7 +1177,7 @@ def test_generate_artifacts_external_data_separate_files():
11761177
artifacts.generate_artifacts(
11771178
os.path.join(temp_dir, "simple_net.onnx"),
11781179
requires_grad=requires_grad_params,
1179-
loss=artifacts.LossType.CrossEntropyLoss,
1180+
loss=loss,
11801181
optimizer=artifacts.OptimType.AdamW,
11811182
artifact_directory=temp_dir,
11821183
)

0 commit comments

Comments
 (0)