@@ -30,7 +30,10 @@ class TestQTensor:
     )
     @pytest.mark.parametrize("device", ["cpu", "cuda"])
    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16])
-    def test_qtensor(self, num_bits, block_sizes, device, input_dtype):
+    @pytest.mark.parametrize(
+        ("input_shape", "check_memory"), [((256, 64), True), ((256, 32), False)]
+    )  # the smaller shape skips the memory-saving check
+    def test_qtensor(self, num_bits, block_sizes, device, input_dtype, input_shape, check_memory):
         nf4_attr_cfg = QuantizerAttributeConfig(
             num_bits=num_bits,
             block_sizes=block_sizes,
@@ -40,7 +43,7 @@ def test_qtensor(self, num_bits, block_sizes, device, input_dtype):

         # Original tensor
         base_mem = torch.cuda.memory_allocated("cuda")
-        x = torch.rand(256, 64).to(device).to(dtype=input_dtype)
+        x = torch.rand(input_shape).to(device).to(dtype=input_dtype)
         x_allocated = torch.cuda.memory_allocated("cuda")
         bf16_mem_usage = x_allocated - base_mem

@@ -51,7 +54,7 @@ def test_qtensor(self, num_bits, block_sizes, device, input_dtype):
         nf4_mem_usage = nf4_x_allocated - base_mem

         # Check the memory saving
-        if bf16_mem_usage > 0:
+        if bf16_mem_usage > 0 and check_memory:
             assert (nf4_mem_usage) / bf16_mem_usage < 0.3

         # De-quantize to origin dtype
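For reference, the 0.3 bound in the memory check follows from the storage format itself: fp16/bf16 tensors use 2 bytes per element, while NF4 packs two 4-bit codes per byte, giving a 0.25 payload ratio before per-block scale overhead. A minimal back-of-the-envelope sketch of that arithmetic in plain PyTorch, independent of `QuantizerAttributeConfig` (the 64-element block size and 16-bit scales here are assumptions for illustration, not taken from this diff):

```python
# Back-of-the-envelope check of the expected NF4 memory saving.
# Illustration only; the real test measures actual CUDA allocations
# via torch.cuda.memory_allocated.
import torch

x = torch.rand(256, 64, dtype=torch.bfloat16)

bf16_bytes = x.numel() * x.element_size()  # 16384 elements * 2 B = 32768 B
nf4_bytes = x.numel() // 2                 # two 4-bit codes per byte = 8192 B
scale_bytes = (x.numel() // 64) * 2        # assumed: one 16-bit scale per 64-element block

ratio = (nf4_bytes + scale_bytes) / bf16_bytes
print(ratio)                               # 0.265625, comfortably below the 0.3 bound
assert ratio < 0.3
```

In practice the measured ratio also depends on allocator rounding, which is presumably why the smaller (256, 32) input is parametrized with `check_memory=False` rather than held to the same fixed bound.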