Skip to content

Commit 6d285bb

Browse files
chunnienccopybara-github
authored andcommitted
fix hf distilbert conversion
PiperOrigin-RevId: 713093948
1 parent b91c2e9 commit 6d285bb

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

ai_edge_torch/odml_torch/_torch_future.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,16 @@ def safe_run_decompositions(exported_program, decomp_table=None):
7373
node.target = lambda self, size: torch.reshape(self.contiguous(), size)
7474

7575
return exported_program.run_decompositions(decomp_table)
76+
77+
78+
def dummy_decomp_table():
79+
"""Build dummy decomp table for run_decompositions without any decompositions.
80+
81+
Compatible for torch<=2.5.
82+
83+
Returns:
84+
Decomp table for ExportedProgram.run_decompositions.
85+
"""
86+
return {
87+
torch._ops.OperatorBase(): lambda: None,
88+
}

ai_edge_torch/odml_torch/export.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
238238
def in_i32(x: int):
239239
return -2147483648 <= x <= 2147483647
240240

241+
def to_int32(x: torch.Tensor):
242+
return torch.ops.aten._to_copy.default(x, dtype=torch.int32)
243+
241244
def rewrite_arange(node: torch.fx.Node):
242245
tensor_meta = node.meta.get("tensor_meta", None)
243246
if not tensor_meta:
@@ -249,7 +252,7 @@ def rewrite_arange(node: torch.fx.Node):
249252
if not (in_i32(start) and in_i32(end)):
250253
return
251254
op = node.target
252-
node.target = lambda *args, **kwargs: op(*args, **kwargs).type(torch.int32)
255+
node.target = lambda *args, **kwargs: to_int32(op(*args, **kwargs))
253256

254257
graph_module = exported_program.graph_module
255258
for node in graph_module.graph.nodes:
@@ -305,8 +308,9 @@ def exported_program_to_mlir(
305308

306309
_convert_i64_to_i32(exported_program)
307310

311+
# No decompositions but just retracing/cananicalization.
308312
exported_program = _torch_future.safe_run_decompositions(
309-
exported_program, lowerings.decompositions()
313+
exported_program, _torch_future.dummy_decomp_table()
310314
)
311315

312316
# Passes below mutate the exported program to a state not executable by torch.

ai_edge_torch/odml_torch/lowerings/decomp.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def decompositions():
5555
],
5656
)
5757

58+
# Override noop aten op decompositions for faster run_decompositions.
59+
decompositions[torch.ops.aten.alias.default] = lambda x: x
60+
decompositions[torch.ops.aten.detach.default] = lambda x: x
61+
5862
# Override _safe_softmax decompositions with regular softmax.
5963
# _safe_softmax introduces additional check-select ops to guard extreme
6064
# input values to softmax, which could make the converted model inefficient

0 commit comments

Comments
 (0)