Skip to content

Commit 34bb9c4

Browse files
sevenEng authored and pytorchmergebot committed
[AOTI] Fix unknown constant type for device-moved constants (pytorch#168138)
### Issue When we have the flag `use_runtime_constant_folding=False`, if we move a constant (buffer or parameter) to a different device, we'll generate a new buf/param during compilation time with a new name where the new device (+counter) will be appended, e.g.: ``` # normalised name orig buf: model_x_submodule_y_buf0_name moved buf: model_x_submodule_y_buf0_name_cpu0 ``` However, these new names are not registered in `V.graph.constants`. During cpp wrapper code generation, they won't be recognised, hence will get the `ConstantType::Unknown`. It'll cause issues for model loading during runtime. https://github.com/pytorch/pytorch/blob/b8a3165d28b672ac6d84128e66265bf471b92a55/torch/_inductor/codegen/cpp_wrapper_cpu.py#L851-L862 ### Fix After we do the new const name allocation following device movement, check if the original constant is any recognised buffer or parameter; if so, register the new ones with the graph as well. ### Failed Unittest before the patch ``` =========================================================================== short test summary info ============================================================================ FAILED [3.9054s] test/inductor/test_aot_inductor.py::AOTInductorTestABICompatibleCpu::test_device_moved_constant_cpu - RuntimeError: Expected to not find "torch::aot_inductor::ConstantType::Unknown" but found it FAILED [3.1852s] test/inductor/test_aot_inductor.py::AOTInductorTestABICompatibleGpu::test_device_moved_constant_cuda - RuntimeError: Expected to not find "torch::aot_inductor::ConstantType::Unknown" but found it ================================================================ 2 failed, 1 skipped, 916 deselected in 11.81s ================================================================= ``` cc. @muchulee8 @desertfire Pull Request resolved: pytorch#168138 Approved by: https://github.com/muchulee8
1 parent 7a064ed commit 34bb9c4

File tree

2 files changed

+68
-1
lines changed

2 files changed

+68
-1
lines changed

test/inductor/test_aot_inductor.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,49 @@ def forward(self, x):
671671
code
672672
)
673673

674+
@requires_gpu
675+
def test_device_moved_constant(self):
676+
# testing both directions
677+
device_movements = [
678+
(torch.device(type=GPU_TYPE, index=0), torch.device("cpu")),
679+
(torch.device("cpu"), torch.device(type=GPU_TYPE, index=0)),
680+
]
681+
682+
class Model(torch.nn.Module):
683+
def __init__(self, from_device):
684+
super().__init__()
685+
self.register_buffer("_buf", torch.randn(6, 7, device=from_device))
686+
self._param = torch.nn.Parameter(
687+
torch.rand(6, 7, device=from_device), requires_grad=False
688+
)
689+
690+
def forward(self, x):
691+
to_device = x.device
692+
moved_buf = self._buf.to(to_device)
693+
moved_param = self._param.to(to_device)
694+
return moved_buf, moved_param
695+
696+
with config.patch(
697+
{
698+
"aot_inductor.use_runtime_constant_folding": False,
699+
}
700+
):
701+
for from_device, to_device in device_movements:
702+
model = Model(from_device)
703+
example_inputs = (torch.randn(6, 7, device=to_device),)
704+
_, code = run_and_get_cpp_code(
705+
AOTIRunnerUtil.compile, model, example_inputs
706+
)
707+
FileCheck().check_not("torch::aot_inductor::ConstantType::Unknown").run(
708+
code
709+
)
710+
FileCheck().check_count(
711+
"torch::aot_inductor::ConstantType::Buffer", 2, exactly=True
712+
).run(code)
713+
FileCheck().check_count(
714+
"torch::aot_inductor::ConstantType::Parameter", 2, exactly=True
715+
).run(code)
716+
674717
def test_subclasses(self):
675718
device_to_init = self.device
676719

torch/_inductor/graph.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1114,11 +1114,35 @@ def constant_name(self, name: str, device_override: Optional[torch.device]) -> s
11141114
with torch.utils._python_dispatch._disable_current_modes():
11151115
# caller might have OrderedSet fake tensor mode which will create a fake tensor
11161116
# when calling .to, so unset modes here
1117-
return self.allocate_non_dup_const_name(
1117+
non_dup_const_name = self.allocate_non_dup_const_name(
11181118
f"{name}_{device_override.type}{device_override.index or 0}",
11191119
self.constants[name].to(device_override),
11201120
)
11211121

1122+
assert non_dup_const_name in self.constants, (
1123+
f"{non_dup_const_name} should be in V.graph.constants already"
1124+
)
1125+
1126+
# register device-copied buffers and parameters to graph as well
1127+
# to codegen correct torch::aot_inductor::ConstantType for them rather than `Unknown`
1128+
if any(
1129+
name == normalize_name(buffer_name)
1130+
for buffer_name in self.named_buffers
1131+
):
1132+
self.named_buffers[non_dup_const_name] = self.constants[
1133+
non_dup_const_name
1134+
]
1135+
1136+
if any(
1137+
name == normalize_name(param_name)
1138+
for param_name in self.named_parameters
1139+
):
1140+
self.named_parameters[non_dup_const_name] = self.constants[
1141+
non_dup_const_name
1142+
]
1143+
1144+
return non_dup_const_name
1145+
11221146
# pyrefly: ignore [bad-override]
11231147
def placeholder(
11241148
self,

0 commit comments

Comments (0)