
Commit 5af71da

build: manually update PyTorch version (#4102)
This commit sets the PyTorch and TorchVision versions to the 2025-03-25 nightly release. It also adds a `strict` flag (defaulting to False) to fx's `export_and_import` method, in line with the upstream change pytorch/pytorch@ab45aac, and updates the `torch.export` call in the `_export_run` method of fx_importer_backend to pass the `strict` flag.

Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent 5698071 commit 5af71da
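
For context, a minimal usage sketch (not part of the diff) of the new flag, assuming a torch-mlir build that includes this change; the `AddOne` module is purely illustrative:

import torch
import torch.nn as nn
from torch_mlir import fx

class AddOne(nn.Module):
    # Illustrative toy module, not taken from this commit.
    def forward(self, x):
        return x + 1.0

# `strict` now defaults to False in export_and_import; pass strict=True to keep
# the previous TorchDynamo-based (strict) export behavior.
module = fx.export_and_import(AddOne(), torch.randn(3, 4), strict=True)
print(module)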

9 files changed: 16 additions, 13 deletions

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 2 additions & 0 deletions

@@ -520,6 +520,7 @@
     "ReflectionPad3dModuleBack_basic",
     # RuntimeError: Unknown function SliceOutOfLowerBoundEndIndexModule
     "SliceOutOfLowerBoundEndIndexModule_basic",
+    "NativeGroupNormModule_basic",
 }
 
 FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
@@ -954,6 +955,7 @@
     "AtenSymConstrainRange_basic",
     "AtenSymConstrainRangeForSize_basic",
     "Aten_AssertScalar_basic",
+    "NativeGroupNormModule_basic",
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {

projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py

Lines changed: 3 additions & 1 deletion

@@ -134,7 +134,9 @@ def invoke_func(*torch_inputs):
     def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
         result: Trace = []
         for item in trace:
-            prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))
+            prog: ExportedProgram = torch.export.export(
+                artifact, tuple(item.inputs), strict=True
+            )
             module = fx.export_and_import(
                 prog,
                 output_type=self._output_type,
python/torch_mlir/fx.py

Lines changed: 4 additions & 1 deletion

@@ -73,6 +73,7 @@ def export_and_import(
     output_type: Union[str, OutputType] = OutputType.RAW,
     fx_importer: Optional[FxImporter] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+    strict: bool = False,
     experimental_support_mutation: bool = False,
     import_symbolic_shape_expressions: bool = False,
     hooks: Optional[FxImporterHooks] = None,
@@ -94,7 +95,9 @@ def export_and_import(
     else:
         # pytorch 2.1 or lower doesn't have `dyanmic_shapes` keyword argument in torch.export
         if version.Version(torch.__version__) >= version.Version("2.2.0"):
-            prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
+            prog = torch.export.export(
+                f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=strict
+            )
         else:
             prog = torch.export.export(f, args, kwargs)
     if decomposition_table is None:
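
As a side note, the new keyword is simply forwarded to `torch.export.export`. A minimal sketch of what that call does with each setting (the `TanhModule` name is illustrative only, not from this commit):

import torch
import torch.nn as nn

class TanhModule(nn.Module):
    # Illustrative module, not taken from this commit.
    def forward(self, x):
        return torch.tanh(x)

example_args = (torch.randn(3, 4),)
# strict=True traces the module with TorchDynamo; strict=False (the new upstream
# default) uses non-strict export. Both return an ExportedProgram.
prog_strict = torch.export.export(TanhModule(), example_args, strict=True)
prog_nonstrict = torch.export.export(TanhModule(), example_args, strict=False)
print(type(prog_strict).__name__, type(prog_nonstrict).__name__)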

pytorch-hash.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-cdb42bd8cc05bef0ec9b682b274c2acb273f2d62
+3794824ceb12a9d4396eaa17795bf2147fd9e1c3

pytorch-requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch/
 --pre
-torch==2.7.0.dev20250310
+torch==2.8.0.dev20250325

test/python/fx_importer/basic_test.py

Lines changed: 2 additions & 7 deletions

@@ -31,7 +31,7 @@ def run(f):
 @run
 # CHECK-LABEL: test_import_frozen_exported_program
 # CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
-# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_4_torch.float32> : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32>
+# CHECK-DAG: %[[a:.+]] = torch.aten.randn
 # CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_3_1_torch.float32> : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32>
 # CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
 # CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]]
@@ -42,7 +42,6 @@ def run(f):
 #
 # Validate dialect resources exist.
 # CHECK: dialect_resources:
-# CHECK-DAG: torch_tensor_1_4_torch.float32
 # CHECK-DAG: torch_tensor_3_1_torch.float32
 def test_import_frozen_exported_program():
     # Tests the basic structural premises of import_frozen_exported_program,
@@ -210,11 +209,7 @@ def forward(self):
 @run
 # CHECK-LABEL: test_stack_trace
 # CHECK: #loc[[LOC1:.+]] = loc(
-# CHECK: #loc[[LOC2:.+]] = loc(
-# CHECK: #loc[[LOC3:.+]] = loc(
-# CHECK: #loc[[LOC4:.+]] = loc(callsite(#loc[[LOC2]] at #loc[[LOC3]]))
-# CHECK: #loc[[LOC5:.+]] = loc(callsite(#loc[[LOC1]] at #loc[[LOC4]]))
-# CHECK: %{{.+}} = torch.aten.add.Tensor {{.+}} loc(#loc[[LOC4]])
+# CHECK: %{{.+}} = torch.aten.add.Tensor {{.+}} loc(#loc[[LOC1]])
 def test_stack_trace():
     class Basic(nn.Module):
         def __init__(self):

test/python/fx_importer/symbolic_shape_expr_test.py

Lines changed: 1 addition & 0 deletions

@@ -222,6 +222,7 @@ def forward(self, x):
         SliceTensorDynamicOutput(),
         x,
         dynamic_shapes=dynamic_shapes,
+        strict=True,
         import_symbolic_shape_expressions=True,
     )
     print(m)

test/python/fx_importer/v2.3/mutation_import.py

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ def run(f):
 # This doesn't do mutation but ensures that the basics remain functional.
 # CHECK-LABEL: test_import_frozen_exported_program
 # CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
-# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_4_torch.float32> : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32>
+# CHECK-DAG: %[[a:.+]] = torch.aten.randn
 # CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_3_1_torch.float32> : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32>
 # CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
 # CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]]

torchvision-requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torchvision/
 --pre
-torchvision==0.22.0.dev20250310
+torchvision==0.22.0.dev20250325
