@@ -57,6 +57,55 @@ def run_diffusers_mmdit(
 
     return noise_pred.numpy()
 
+def run_attn_turbine(q, k, v, args):
+    # Run q/k/v attention through the compiled VMFB via the IREE runtime.
+    attn_runner = vmfbRunner(
+        args.device,
+        args.vmfb_path,
+        None,
+    )
+    iree_inputs = [
+        ireert.asdevicearray(attn_runner.config.device, q),
+        ireert.asdevicearray(attn_runner.config.device, k),
+        ireert.asdevicearray(attn_runner.config.device, v),
+    ]
+    attn_output = attn_runner.ctx.modules.compiled_attn["run_forward"](
+        *iree_inputs
+    ).to_host()
+    return attn_output
+
+@torch.no_grad()
+def run_attn_torch(q, k, v, args):
+    # Reference path: the same attention computed in eager PyTorch at float32.
+    from turbine_models.custom_models.sd3_inference.sd3_mmdit import MMDiTAttention
+
+    mmdit_attn = MMDiTAttention()
+    attn_output = mmdit_attn.forward(
+        torch.tensor(q, dtype=torch.float32),
+        torch.tensor(k, dtype=torch.float32),
+        torch.tensor(v, dtype=torch.float32),
+    )
+
+    return attn_output.numpy()
+
+def find_errs(turbine_output, torch_output, dim=None, failed_dims=None, errs=None):
+    # Recursively record the index paths (and absolute errors) of sub-arrays that
+    # exceed the 4e-2 tolerance; None defaults avoid shared mutable default lists.
+    if dim is None:
+        dim, failed_dims, errs = [], [], []
+    if not np.allclose(turbine_output, torch_output, rtol=4e-2, atol=4e-2):
+        if turbine_output.ndim > 0:
+            orig_dim = dim
+            for idx, _ in enumerate(torch_output):
+                dim = [*orig_dim, idx]
+                try:
+                    np.testing.assert_allclose(turbine_output[idx], torch_output[idx], rtol=4e-2, atol=4e-2)
+                except Exception:
+                    err = np.abs(turbine_output[idx] - torch_output[idx])
+                    failed_dims.append(dim)
+                    errs.append([err, turbine_output[idx], torch_output[idx]])
+                    failed_dims, errs = find_errs(turbine_output[idx], torch_output[idx], dim, failed_dims, errs)
+    return (failed_dims, errs)
 
 if __name__ == "__main__":
     from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args
@@ -69,6 +118,32 @@ def run_diffusers_mmdit(
         dtype = torch.float16
     else:
         dtype = torch.float32
+
+    if args.attn_repro:
+        # Repro path: compare the compiled (Turbine/IREE) attention against eager
+        # PyTorch on q/k/v tensors saved in the working directory.
+        qkv_shape = (2, 24, 4250, 64)  # expected q/k/v shape (currently unused)
+        example_qkv = [
+            np.load("q.npy").astype(np.float16),
+            np.load("k.npy").astype(np.float16),
+            np.load("v.npy").astype(np.float16),
+        ]
+        turbine_output = run_attn_turbine(
+            *example_qkv,
+            args,
+        )
+        torch_output = run_attn_torch(*example_qkv, args).astype(np.float16)
+        np.save("turbine_attn_output.npy", turbine_output)
+        np.save("torch_attn_output.npy", torch_output)
+        failed_dims, errs = find_errs(turbine_output, torch_output)
+        for idx, dim in enumerate(failed_dims):
+            # Only report fully indexed (element-level) mismatches.
+            if len(dim) == len(torch_output.shape):
+                print("Failed dimension: ", dim, " with error: ", errs[idx][0])
+                print("Turbine output: ", errs[idx][1])
+                print("Torch output: ", errs[idx][2])
+        print(torch_output.shape)
+        exit()
 
     batch_size = args.batch_size * 2  # do classifier-free guidance
     hidden_states = torch.randn(
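Note (not part of the diff): a minimal sketch of how the q.npy / k.npy / v.npy files consumed by the attn_repro branch above could be produced. The file names, the (2, 24, 4250, 64) shape, and the float16 dtype come from the diff; filling them with random data is only an illustrative assumption.

import numpy as np

# Write random q/k/v tensors with the shape and dtype the repro path loads
# from the current working directory.
qkv_shape = (2, 24, 4250, 64)
rng = np.random.default_rng(0)
for name in ("q", "k", "v"):
    np.save(f"{name}.npy", rng.standard_normal(qkv_shape).astype(np.float16))

With those files in place, running the script with the attention-repro flag (presumably exposed via sd3_cmd_opts as args.attn_repro, alongside args.device and args.vmfb_path for the compiled attention module) saves turbine_attn_output.npy and torch_attn_output.npy and prints any element-level mismatches reported by find_errs.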