Skip to content

Commit 8dc9f5f

Browse files
authored
[Relax][PyTorch] Fix KeyError: dtype when converting PyTorch model with gradient checkpointing using torch.export (#18461)
This PR is trying to fix issues #18439. Co-authored-by: cchung100m <[email protected]>
1 parent ea89f21 commit 8dc9f5f

File tree

3 files changed

+27
-1
lines changed

3 files changed

+27
-1
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2036,7 +2036,11 @@ def _arange(self, node: fx.Node) -> relax.Var:
20362036
return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype))
20372037

20382038
def _empty(self, node: fx.Node) -> relax.Var:
2039-
dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
2039+
import torch
2040+
2041+
dtype = self._convert_data_type(
2042+
node.kwargs.get("dtype", torch.get_default_dtype()), self.env
2043+
)
20402044
return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
20412045

20422046
def _empty_like(self, node: fx.Node) -> relax.Var:

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,7 @@ def create_convert_map(
11431143
"_assert_tensor_metadata.default": lambda node: self.env[
11441144
node.args[0]
11451145
], # metadata assertion: no-op
1146+
"empty.default": self._empty,
11461147
"empty.memory_format": self._empty,
11471148
"empty_permuted.default": self._empty, # Similar to empty with permuted layout
11481149
"empty_like.default": self._empty_like,

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5278,6 +5278,27 @@ def main(
52785278
verify_model(Empty(), example_args, {}, Expected, run_ep_decomposition=True)
52795279

52805280

5281+
def test_empty_without_dtype():
5282+
class EmptyWithoutDtype(Module):
5283+
def forward(self, input):
5284+
return torch.empty((5, 5))
5285+
5286+
@tvm.script.ir_module
5287+
class Expected:
5288+
@R.function
5289+
def main(
5290+
input: R.Tensor((10, 10), dtype="float32")
5291+
) -> R.Tuple(R.Tensor((5, 5), dtype="float32")):
5292+
with R.dataflow():
5293+
lv: R.Tensor((5, 5), dtype="float32") = R.zeros(R.shape([5, 5]), dtype="float32")
5294+
gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,)
5295+
R.output(gv)
5296+
return gv
5297+
5298+
example_args = (torch.randn(10, 10, dtype=torch.float32),)
5299+
verify_model(EmptyWithoutDtype(), example_args, {}, Expected, run_ep_decomposition=True)
5300+
5301+
52815302
def test_fill():
52825303
class Fill(Module):
52835304
def forward(self, input: torch.Tensor):

0 commit comments

Comments
 (0)