88from torch .export import export
99from torch .fx import GraphModule
1010
11+ from tensorrt_llm ._torch .auto_deploy .custom_ops .attention_interface import SequenceInfo
1112from tensorrt_llm ._torch .auto_deploy .export import torch_export_to_gm
1213from tensorrt_llm ._torch .auto_deploy .transformations .library .sharding import ShardingTransformInfo
1314
@@ -20,6 +21,21 @@ def build_model(self, device: str) -> nn.Module:
2021 return self .model .to (device = device )
2122
2223
24+ class SequenceEmbeddingInfo (SequenceInfo ):
25+ hidden_size : int
26+ dtype : torch .dtype
27+
28+ def set_example_sequence (self ) -> None :
29+ super ().set_example_sequence ()
30+ # set input ids to a 3D tensor (actually input embeddings)
31+ self .input_ids = torch .rand (
32+ * self .input_ids .shape ,
33+ self .hidden_size ,
34+ device = self .input_ids .device ,
35+ dtype = self .dtype ,
36+ )
37+
38+
2339def count_parameters (model : torch .nn .Module ):
2440 for n , p in model .named_parameters ():
2541 print (n , p .shape )
@@ -32,6 +48,79 @@ def count_buffers(model: torch.nn.Module):
3248 return sum (np .prod (b .shape ) for b in model .buffers ())
3349
3450
51+ def run_test_transformed_gm (
52+ model : nn .Module ,
53+ x : torch .Tensor ,
54+ gm_transformed : GraphModule ,
55+ check_transformed_graph : Callable [[GraphModule ], bool ],
56+ _get_expected_num_params : Callable [[int ], int ],
57+ atol : float = 1e-3 ,
58+ rtol : float = 1e-3 ,
59+ test_load_hook : bool = True ,
60+ strict_loading : bool = True ,
61+ dynamic_shapes : Dict = None ,
62+ skip_output_assert : bool = False ,
63+ * args , # Additional arguments for transform
64+ ) -> GraphModule :
65+ # run model once
66+ y_model = model (x )
67+
68+ # num params
69+ num_params_model = count_parameters (model )
70+ print (num_params_model )
71+
72+ # export + check (we clone the state dict to have a bit more freedom in testing below)
73+ gm_ref = torch_export_to_gm (model , args = (x ,), dynamic_shapes = (dynamic_shapes ,), clone = True )
74+ print (gm_ref )
75+ y_gm = gm_ref (x )
76+ num_params_gm = count_parameters (gm_ref )
77+
78+ assert num_params_model == num_params_gm
79+ if not skip_output_assert :
80+ torch .testing .assert_close (y_model , y_gm , atol = atol , rtol = rtol )
81+
82+ print (gm_transformed )
83+ # in case buffers or other tensors were added during the transform
84+ gm_transformed = gm_transformed .to ("cuda" )
85+ y_transformed = gm_transformed (x )
86+ n_p_transformed = count_parameters (gm_transformed )
87+
88+ n_p_t_expected = _get_expected_num_params (num_params_model )
89+ assert n_p_transformed == n_p_t_expected , (
90+ f"actual params { n_p_transformed } != expected params { n_p_t_expected } "
91+ )
92+
93+ # check if the transformation worked
94+ assert check_transformed_graph (gm_transformed )
95+
96+ if strict_loading and not skip_output_assert :
97+ # check if output equals without loading state dict
98+ torch .testing .assert_close (y_model , y_transformed , atol = atol , rtol = rtol )
99+
100+ if test_load_hook and not skip_output_assert :
101+ # check if loading hook works from original state dict
102+ reset_parameters (gm_transformed )
103+ y_random = gm_transformed (x )
104+ assert not all_close (y_model , y_random ), f"{ y_model = } , { y_random = } "
105+
106+ gm_transformed .load_state_dict (model .state_dict (), strict = True if strict_loading else False )
107+ y_loaded_from_original = gm_transformed (x )
108+ torch .testing .assert_close (y_model , y_loaded_from_original , atol = atol , rtol = rtol )
109+
110+ # check if loading hook works from state_dict of a transformed model
111+ state_dict_sharded = copy .deepcopy (gm_transformed .state_dict ())
112+ reset_parameters (gm_transformed )
113+ y_random2 = gm_transformed (x )
114+ assert not all_close (y_model , y_random2 ), f"{ y_model = } , { y_random2 = } "
115+
116+ gm_transformed .load_state_dict (state_dict_sharded , strict = True if strict_loading else False )
117+ y_loaded_from_transformed = gm_transformed (x )
118+ torch .testing .assert_close (y_model , y_loaded_from_transformed , atol = atol , rtol = rtol )
119+
120+ # check if we can still export the model as expected
121+ export (gm_transformed , args = (x ,))
122+
123+
35124def run_test (
36125 model : nn .Module ,
37126 x : torch .Tensor ,
0 commit comments