@@ -66,7 +66,7 @@ def test_bench_matmul(batch, seq, model, hidden):
6666 torch .matmul (A , B .t ())
6767 torch .cuda .synchronize ()
6868 print (
69- f"pytorch fp16: [{ batch } ,{ seq } ,{ model } ], [{ model } ,{ hidden } ]->[{ batch } ,{ seq } ,{ hidden } ]: { time .time ()- t0 :.4f} s" ,
69+ f"pytorch fp16: [{ batch } ,{ seq } ,{ model } ], [{ model } ,{ hidden } ]->[{ batch } ,{ seq } ,{ hidden } ]: { time .time () - t0 :.4f} s" ,
7070 )
7171
7272 # torch.cuda.synchronize()
@@ -88,22 +88,24 @@ def test_bench_matmul(batch, seq, model, hidden):
8888 for i in range (iters ):
8989 bnb .matmul_4bit (A , B_nf4 .t (), quant_state = state_nf4 )
9090 torch .cuda .synchronize ()
91- print (f"bnb nf4: [{ batch } ,{ seq } ,{ model } ], [{ model } ,{ hidden } ]->[{ batch } ,{ seq } ,{ hidden } ]: { time .time ()- t0 :.4f} s" )
91+ print (f"bnb nf4: [{ batch } ,{ seq } ,{ model } ], [{ model } ,{ hidden } ]->[{ batch } ,{ seq } ,{ hidden } ]: { time .time () - t0 :.4f} s" )
9292
9393 torch .cuda .synchronize ()
9494 t0 = time .time ()
9595 for i in range (iters ):
9696 bnb .matmul_4bit (A , B_nf4_c .t (), quant_state = state_nf4_c )
9797 torch .cuda .synchronize ()
98- print (f"bnb nf4+DQ: [{ batch } ,{ seq } ,{ model } ], [{ model } ,{ hidden } ]->[{ batch } ,{ seq } ,{ hidden } ]: { time .time ()- t0 :.4f} s" )
98+ print (
99+ f"bnb nf4+DQ: [{ batch } ,{ seq } ,{ model } ], [{ model } ,{ hidden } ]->[{ batch } ,{ seq } ,{ hidden } ]: { time .time () - t0 :.4f} s"
100+ )
99101
100102 torch .cuda .synchronize ()
101103 t0 = time .time ()
102104 for i in range (iters ):
103105 bnb .matmul (A , B )
104106 torch .cuda .synchronize ()
105107 print (
106- f"B -> CB (each iteration): [{ batch } ,{ seq } ,{ model } ], [{ model } ,{ hidden } ]->[{ batch } ,{ seq } ,{ hidden } ]: { time .time ()- t0 :.4f} s"
108+ f"B -> CB (each iteration): [{ batch } ,{ seq } ,{ model } ], [{ model } ,{ hidden } ]->[{ batch } ,{ seq } ,{ hidden } ]: { time .time () - t0 :.4f} s"
107109 )
108110
109111 torch .cuda .synchronize ()
@@ -112,7 +114,7 @@ def test_bench_matmul(batch, seq, model, hidden):
112114 bnb .matmul (A , B , threshold = 6.0 )
113115 torch .cuda .synchronize ()
114116 print (
115- f"B -> CB + threshold: [{ batch } ,{ seq } ,{ model } ], [{ model } ,{ hidden } ]->[{ batch } ,{ seq } ,{ hidden } ]: { time .time ()- t0 :.4f} s"
117+ f"B -> CB + threshold: [{ batch } ,{ seq } ,{ model } ], [{ model } ,{ hidden } ]->[{ batch } ,{ seq } ,{ hidden } ]: { time .time () - t0 :.4f} s"
116118 )
117119
118120 CA , SCA , _ = F .int8_vectorwise_quant (A , threshold = 0.0 )
@@ -124,7 +126,7 @@ def test_bench_matmul(batch, seq, model, hidden):
124126 out32 = F .int8_linear_matmul (CA , CB )
125127 torch .cuda .synchronize ()
126128 print (
127- f"no overhead int8 [{ batch } ,{ seq } ,{ model } ], [{ model } ,{ hidden } ]->[{ batch } ,{ seq } ,{ hidden } ]: { time .time ()- t0 :.4f} s"
129+ f"no overhead int8 [{ batch } ,{ seq } ,{ model } ], [{ model } ,{ hidden } ]->[{ batch } ,{ seq } ,{ hidden } ]: { time .time () - t0 :.4f} s"
128130 )
129131
130132 # C32A, SA = F.transform(CA, "col32")
@@ -183,7 +185,7 @@ def test_bench_matmul(batch, seq, model, hidden):
183185 linear8bit (A )
184186 torch .cuda .synchronize ()
185187 print (
186- f"bnb linear8bitlt (eval): [{ batch } ,{ seq } ,{ model } ], [{ model } ,{ hidden } ]->[{ batch } ,{ seq } ,{ hidden } ]: { time .time ()- t0 :.4f} s"
188+ f"bnb linear8bitlt (eval): [{ batch } ,{ seq } ,{ model } ], [{ model } ,{ hidden } ]->[{ batch } ,{ seq } ,{ hidden } ]: { time .time () - t0 :.4f} s"
187189 )
188190
189191 linearMixedBit (A )
@@ -193,7 +195,7 @@ def test_bench_matmul(batch, seq, model, hidden):
193195 linearMixedBit (A )
194196 torch .cuda .synchronize ()
195197 print (
196- f"bnb linear8bitlt with threshold (eval): [{ batch } ,{ seq } ,{ model } ], [{ model } ,{ hidden } ]->[{ batch } ,{ seq } ,{ hidden } ]: { time .time ()- t0 :.4f} s"
198+ f"bnb linear8bitlt with threshold (eval): [{ batch } ,{ seq } ,{ model } ], [{ model } ,{ hidden } ]->[{ batch } ,{ seq } ,{ hidden } ]: { time .time () - t0 :.4f} s"
197199 )
198200
199201 # linear8bit_train(A)
0 commit comments