@@ -440,31 +440,23 @@ def test_4bit_linear_warnings(device):
440440 dim1 = 64
441441
442442 with pytest .warns (UserWarning , match = r"inference or training" ):
443- net = nn .Sequential (
444- * [bnb .nn .Linear4bit (dim1 , dim1 , quant_type = "nf4" , compute_dtype = torch .float32 ) for i in range (10 )]
445- )
443+ net = nn .Sequential (* [bnb .nn .Linear4bit (dim1 , dim1 , quant_type = "nf4" ) for i in range (10 )])
446444 net = net .to (device )
447445 inp = torch .rand (10 , dim1 , device = device , dtype = torch .float16 )
448446 net (inp )
449447 with pytest .warns (UserWarning , match = r"inference." ):
450- net = nn .Sequential (
451- * [bnb .nn .Linear4bit (dim1 , dim1 , quant_type = "nf4" , compute_dtype = torch .float32 ) for i in range (10 )]
452- )
448+ net = nn .Sequential (* [bnb .nn .Linear4bit (dim1 , dim1 , quant_type = "nf4" ) for i in range (10 )])
453449 net = net .to (device )
454450 inp = torch .rand (1 , dim1 , device = device , dtype = torch .float16 )
455451 net (inp )
456452
457453 with pytest .warns (UserWarning ) as record :
458- net = nn .Sequential (
459- * [bnb .nn .Linear4bit (dim1 , dim1 , quant_type = "nf4" , compute_dtype = torch .float32 ) for i in range (10 )]
460- )
454+ net = nn .Sequential (* [bnb .nn .Linear4bit (dim1 , dim1 , quant_type = "nf4" ) for i in range (10 )])
461455 net = net .to (device )
462456 inp = torch .rand (10 , dim1 , device = device , dtype = torch .float16 )
463457 net (inp )
464458
465- net = nn .Sequential (
466- * [bnb .nn .Linear4bit (dim1 , dim1 , quant_type = "nf4" , compute_dtype = torch .float32 ) for i in range (10 )]
467- )
459+ net = nn .Sequential (* [bnb .nn .Linear4bit (dim1 , dim1 , quant_type = "nf4" ) for i in range (10 )])
468460 net = net .to (device )
469461 inp = torch .rand (1 , dim1 , device = device , dtype = torch .float16 )
470462 net (inp )
0 commit comments