
Commit 9662d81

feat: update graph test helper for testing new inf optimizer (#106)
Signed-off-by: haoguo <[email protected]>
1 parent b8c5b9c

1 file changed: +89 −0 lines changed


tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py

Lines changed: 89 additions & 0 deletions
```diff
@@ -8,6 +8,7 @@
 from torch.export import export
 from torch.fx import GraphModule
 
+from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
 from tensorrt_llm._torch.auto_deploy.transformations.library.sharding import ShardingTransformInfo
 
```
```diff
@@ -20,6 +21,21 @@ def build_model(self, device: str) -> nn.Module:
         return self.model.to(device=device)
 
 
+class SequenceEmbeddingInfo(SequenceInfo):
+    hidden_size: int
+    dtype: torch.dtype
+
+    def set_example_sequence(self) -> None:
+        super().set_example_sequence()
+        # set input ids to a 3D tensor (actually input embeddings)
+        self.input_ids = torch.rand(
+            *self.input_ids.shape,
+            self.hidden_size,
+            device=self.input_ids.device,
+            dtype=self.dtype,
+        )
+
+
 def count_parameters(model: torch.nn.Module):
     for n, p in model.named_parameters():
         print(n, p.shape)
```
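
The new `SequenceEmbeddingInfo` replaces the usual 2D `input_ids` token tensor with a 3D tensor of random values that stands in for input embeddings, so graph tests can feed embedding-level inputs to the new optimizer. Because `hidden_size` and `dtype` are declared as plain attributes rather than constructor arguments, a test presumably sets them after construction and before calling `set_example_sequence()`. A minimal sketch, assuming the base `SequenceInfo` constructor accepts `max_seq_len`/`max_batch_size` kwargs (not shown in this diff):

```python
# Hypothetical usage sketch; only hidden_size, dtype, and the
# set_example_sequence() override come from the commit above.
import torch

from _graph_test_helpers import SequenceEmbeddingInfo  # import path assumed

seq_info = SequenceEmbeddingInfo(max_seq_len=64, max_batch_size=2)  # kwargs assumed
seq_info.hidden_size = 128       # set the attributes the subclass declares
seq_info.dtype = torch.float16
seq_info.set_example_sequence()  # input_ids now holds random embeddings
assert seq_info.input_ids.dim() == 3  # (batch, seq_len, hidden_size)
```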
```diff
@@ -32,6 +48,79 @@ def count_buffers(model: torch.nn.Module):
     return sum(np.prod(b.shape) for b in model.buffers())
 
 
+def run_test_transformed_gm(
+    model: nn.Module,
+    x: torch.Tensor,
+    gm_transformed: GraphModule,
+    check_transformed_graph: Callable[[GraphModule], bool],
+    _get_expected_num_params: Callable[[int], int],
+    atol: float = 1e-3,
+    rtol: float = 1e-3,
+    test_load_hook: bool = True,
+    strict_loading: bool = True,
+    dynamic_shapes: Dict = None,
+    skip_output_assert: bool = False,
+    *args,  # Additional arguments for transform
+) -> GraphModule:
+    # run model once
+    y_model = model(x)
+
+    # num params
+    num_params_model = count_parameters(model)
+    print(num_params_model)
+
+    # export + check (we clone the state dict to have a bit more freedom in testing below)
+    gm_ref = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
+    print(gm_ref)
+    y_gm = gm_ref(x)
+    num_params_gm = count_parameters(gm_ref)
+
+    assert num_params_model == num_params_gm
+    if not skip_output_assert:
+        torch.testing.assert_close(y_model, y_gm, atol=atol, rtol=rtol)
+
+    print(gm_transformed)
+    # in case buffers or other tensors were added during the transform
+    gm_transformed = gm_transformed.to("cuda")
+    y_transformed = gm_transformed(x)
+    n_p_transformed = count_parameters(gm_transformed)
+
+    n_p_t_expected = _get_expected_num_params(num_params_model)
+    assert n_p_transformed == n_p_t_expected, (
+        f"actual params {n_p_transformed} != expected params {n_p_t_expected}"
+    )
+
+    # check if the transformation worked
+    assert check_transformed_graph(gm_transformed)
+
+    if strict_loading and not skip_output_assert:
+        # check if output equals without loading state dict
+        torch.testing.assert_close(y_model, y_transformed, atol=atol, rtol=rtol)
+
+    if test_load_hook and not skip_output_assert:
+        # check if loading hook works from original state dict
+        reset_parameters(gm_transformed)
+        y_random = gm_transformed(x)
+        assert not all_close(y_model, y_random), f"{y_model=}, {y_random=}"
+
+        gm_transformed.load_state_dict(model.state_dict(), strict=True if strict_loading else False)
+        y_loaded_from_original = gm_transformed(x)
+        torch.testing.assert_close(y_model, y_loaded_from_original, atol=atol, rtol=rtol)
+
+        # check if loading hook works from state_dict of a transformed model
+        state_dict_sharded = copy.deepcopy(gm_transformed.state_dict())
+        reset_parameters(gm_transformed)
+        y_random2 = gm_transformed(x)
+        assert not all_close(y_model, y_random2), f"{y_model=}, {y_random2=}"
+
+        gm_transformed.load_state_dict(state_dict_sharded, strict=True if strict_loading else False)
+        y_loaded_from_transformed = gm_transformed(x)
+        torch.testing.assert_close(y_model, y_loaded_from_transformed, atol=atol, rtol=rtol)
+
+    # check if we can still export the model as expected
+    export(gm_transformed, args=(x,))
+
+
 def run_test(
     model: nn.Module,
     x: torch.Tensor,
```
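
Unlike the existing `run_test`, which applies a transformation itself, `run_test_transformed_gm` accepts an already-transformed `GraphModule` and only verifies it: output closeness against the eager model, the expected parameter count, the caller-supplied graph predicate, the state-dict load hooks, and re-exportability. A hypothetical caller sketch; the toy model, the identity "transform", and both lambdas below are illustrative assumptions rather than code from this commit:

```python
# Hypothetical caller sketch, not part of the commit.
import torch
import torch.nn as nn

from _graph_test_helpers import run_test_transformed_gm  # import path assumed
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm

model = nn.Linear(16, 16).to("cuda")
x = torch.randn(4, 16, device="cuda")

# Stand-in for a real transform: export the model and leave the graph unchanged.
gm = torch_export_to_gm(model, args=(x,), clone=True)

run_test_transformed_gm(
    model,
    x,
    gm,
    check_transformed_graph=lambda g: True,  # real tests inspect g.graph here
    _get_expected_num_params=lambda n: n,    # an identity transform keeps the count
)
```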
