@@ -192,7 +192,7 @@ def _test_output(self, model, x, kind_in_graph=None, kind_not_in_graph=None):
192192 if kind_in_graph is not None :
193193 self .assertTrue (any (n .kind () == kind_in_graph for n in script_graph .nodes ()))
194194 self .assertTrue (any (n .kind () == kind_in_graph for n in trace_graph .nodes ()))
195-
195+
196196 # check if certain node does not exist in the graph
197197 if kind_not_in_graph is not None :
198198 self .assertTrue (all (n .kind () != kind_not_in_graph for n in script_graph .nodes ()))
@@ -205,12 +205,12 @@ def _test_output_bf16(self, model, x, kind_in_graph=None, kind_not_in_graph=None
205205 core .enable_auto_dnnl ()
206206 core .enable_jit_opt ()
207207 core .enable_mix_bf16_fp32 ()
208-
208+
209209 model = model .to (ipex .DEVICE ).eval ()
210210 x = x .to (ipex .DEVICE )
211211 x2 = x .clone ()
212212 x3 = x .clone ()
213-
213+
214214 script_fused_model = torch .jit .script (copy .deepcopy (model ))
215215 trace_fused_model = torch .jit .trace (copy .deepcopy (model ), x3 )
216216
@@ -224,24 +224,24 @@ def _test_output_bf16(self, model, x, kind_in_graph=None, kind_not_in_graph=None
224224 # bf 16, jit trace path
225225 trace_graph = trace_fused_model .graph_for (x3 )
226226 fused_tresult = trace_fused_model (x3 )
227-
228- # disable mix_bf16_fp32 when the calculation is done
227+
228+ # disable mix_bf16_fp32 when the calculation is done
229229 # to avoid affecting other scripts
230230 core .disable_mix_bf16_fp32 ()
231-
231+
232232 self .assertEqual (fused_sresult , result , prec = prec )
233233 self .assertEqual (fused_tresult , result , prec = prec )
234234
235235 # check if the fused node exists in the graph
236236 if kind_in_graph is not None :
237237 self .assertTrue (any (n .kind () == kind_in_graph for n in script_graph .nodes ()))
238238 self .assertTrue (any (n .kind () == kind_in_graph for n in trace_graph .nodes ()))
239-
239+
240240 # check if certain node does not exist in the graph
241241 if kind_not_in_graph is not None :
242242 self .assertTrue (all (n .kind () != kind_not_in_graph for n in script_graph .nodes ()))
243243 self .assertTrue (all (n .kind () != kind_not_in_graph for n in trace_graph .nodes ()))
244-
244+
245245
246246 def test_output_conv_bn_2d (self ):
247247 self ._test_output (
@@ -364,6 +364,23 @@ def test_output_linear_relu(self):
364364 kind_in_graph = "ipex::linear_relu" )
365365
366366
367+ def test_jit_function (self ):
368+ # test hool trace and script can works for function
369+ def fn (input , weight , bias ):
370+ return F .linear (input , weight , bias )
371+
372+ input = torch .randn (2 , 4 )
373+ weight = torch .randn (5 , 4 )
374+ bias = torch .randn (5 )
375+ result = fn (input , weight , bias )
376+
377+ scripted_fn = torch .jit .script (fn )
378+ traced_fn = torch .jit .trace (fn , (input , weight , bias ))
379+
380+ self .assertEqual (scripted_fn (input , weight , bias ), result )
381+ self .assertEqual (traced_fn (input , weight , bias ), result )
382+
383+
367384if __name__ == '__main__' :
368385 torch .manual_seed (2020 )
369386 core .enable_auto_dnnl ()
0 commit comments