File tree Expand file tree Collapse file tree 1 file changed +2
-4
lines changed Expand file tree Collapse file tree 1 file changed +2
-4
lines changed Original file line number Diff line number Diff line change @@ -285,24 +285,22 @@ def test_linear_kbit_fp32_bias(device, module):
285285@pytest .mark .parametrize ("device" , get_available_devices ())
286286@pytest .mark .parametrize ("module" , module_dict .values (), ids = module_dict .keys ())
287287def test_kbit_backprop (device , module ):
288- if device == "cpu" :
289- pytest .xfail ("Test is not yet supported on CPU" )
290-
291288 b = 16
292289 dim1 = 36
293290 dim2 = 84
294291 # dim1 = 37
295292 # dim2 = 83
296293
297294 ref = nn .Sequential (* [torch .nn .Linear (dim1 , dim2 ), torch .nn .Linear (dim2 , 128 )])
298- # ref[1].weight.requires_grad = False
299295 torch .nn .init .kaiming_normal_ (ref [0 ].weight )
300296 torch .nn .init .kaiming_normal_ (ref [1 ].weight )
297+ ref [1 ].weight .requires_grad_ (False )
301298 kbit = nn .Sequential (* [torch .nn .Linear (dim1 , dim2 ), module (dim2 , 128 )])
302299 kbit [0 ].weight .detach ().copy_ (ref [0 ].weight )
303300 kbit [1 ].weight .detach ().copy_ (ref [1 ].weight )
304301 kbit [0 ].bias .detach ().copy_ (ref [0 ].bias )
305302 kbit [1 ].bias .detach ().copy_ (ref [1 ].bias )
303+ kbit [1 ].weight .requires_grad_ (False )
306304 ref = ref .half ().to (device )
307305 kbit = kbit .half ().to (device )
308306 kbit = kbit .half ().to (device )
You can’t perform that action at this time.
0 commit comments