@@ -261,7 +261,7 @@ def test_cutlass_backend_subproc_mm(self):
261261 M , N , K = 4096 , 2048 , 25728
262262
263263 a = torch .randn (M , K ).cuda ().half ()
264- b = torch .randn (N , K ).cuda ().half (). t ()
264+ b = torch .randn (K , N ).cuda ().half ()
265265
266266 with config .patch (
267267 {
@@ -289,7 +289,7 @@ def test_cutlass_backend_subproc_addmm(self, shape_combo):
289289 M , N , K = 4096 , 2048 , 25728
290290
291291 a = torch .randn (M , K ).cuda ().half ()
292- b = torch .randn (N , K ).cuda ().half (). t ()
292+ b = torch .randn (K , N ).cuda ().half ()
293293
294294 x_shapes = [
295295 (M , N ),
@@ -326,7 +326,7 @@ def test_cutlass_backend_subproc_bmm(self):
326326 B , M , N , K = 10 , 4096 , 2048 , 25728
327327
328328 a = torch .randn (B , M , K ).cuda ().half ()
329- b = torch .randn (B , N , K ).cuda ().half (). permute ( 0 , 2 , 1 )
329+ b = torch .randn (B , K , N ).cuda ().half ()
330330
331331 with config .patch (
332332 {
@@ -358,8 +358,8 @@ def forward(self, a, b, c):
358358
359359 model = MyModel ()
360360 a = torch .randn (128 , 16 ).cuda ().half ()
361- b = torch .randn (128 , 16 ).cuda ().half (). t ()
362- c = torch .randn (512 , 16 ).cuda ().half (). t ()
361+ b = torch .randn (16 , 128 ).cuda ().half ()
362+ c = torch .randn (16 , 512 ).cuda ().half ()
363363
364364 with config .patch (
365365 {
@@ -400,8 +400,8 @@ def forward(self, a, b, c):
400400
401401 model = MyModel ()
402402 a = torch .randn (128 , 16 ).cuda ().half ()
403- b = torch .randn (128 , 16 ).cuda ().half (). t ()
404- c = torch .randn (512 , 16 ).cuda ().half (). t ()
403+ b = torch .randn (16 , 128 ).cuda ().half ()
404+ c = torch .randn (16 , 512 ).cuda ().half ()
405405
406406 with config .patch (
407407 {
@@ -465,7 +465,7 @@ def forward(self, a, b):
465465 model = MyModel ().cuda ()
466466
467467 inputs = [
468- (torch .randn (M , K ).cuda ().to (dtype ), torch .randn (N , K ).cuda ().to (dtype ). t ( ))
468+ (torch .randn (M , K ).cuda ().to (dtype ), torch .randn (K , N ).cuda ().to (dtype ))
469469 for (M , N , K ) in shapes
470470 ]
471471
@@ -633,7 +633,7 @@ def forward(self, x, a, b):
633633 (
634634 torch .randn (x_shape (M , N )).cuda ().to (dtype ),
635635 torch .randn (M , K ).cuda ().to (dtype ),
636- torch .randn (N , K ).cuda ().to (dtype ). t ( ),
636+ torch .randn (K , N ).cuda ().to (dtype ),
637637 )
638638 for (M , N , K ) in shapes
639639 ]
@@ -744,7 +744,7 @@ def mm(a, b):
744744 return a @ b
745745
746746 a = torch .randn (128 , 16 ).cuda ().half ()
747- b = torch .randn (128 , 16 ).cuda ().half (). t ()
747+ b = torch .randn (16 , 128 ).cuda ().half ()
748748
749749 with config .patch (
750750 {
@@ -770,7 +770,7 @@ def mm(a, b):
770770 ),
771771 ):
772772 a = torch .randn (M , K ).cuda ().half ()
773- b = torch .randn (N , K ).cuda ().half (). t ()
773+ b = torch .randn (K , N ).cuda ().half ()
774774 Y_compiled = torch .compile (mm , dynamic = dynamic )(a , b )
775775 Y = mm (a , b )
776776 # we need relaxed numerical limits due to the sheer size of the
@@ -935,7 +935,7 @@ def forward(self, x, w):
935935 }
936936
937937 x = torch .randn (M , K ).cuda ().half ()
938- w = torch .randn (N , K ).cuda ().half (). t ()
938+ w = torch .randn (K , N ).cuda ().half ()
939939
940940 actual = AOTIRunnerUtil .run (
941941 model ,
@@ -973,7 +973,7 @@ def forward(self, x, w):
973973 }
974974
975975 x = torch .randn (M , K ).cuda ().half ()
976- w = torch .randn (N , K ).cuda ().half (). t ()
976+ w = torch .randn (K , N ).cuda ().half ()
977977
978978 actual = AOTIRunnerUtil .run (
979979 model ,
@@ -1003,7 +1003,7 @@ def forward(self, x, w):
10031003 M , N , K = 200 , 5216 , 10_432
10041004
10051005 x = torch .randn (M , K ).cuda ().half ()
1006- w = torch .randn (N , K ).cuda ().half (). t ()
1006+ w = torch .randn (K , N ).cuda ().half ()
10071007
10081008 actual = AOTIRunnerUtil .run (
10091009 model ,
@@ -1032,7 +1032,7 @@ def mm(a, b):
10321032 mask = torch .tensor ([0 , 0 , 1 , 1 ]).tile (m , k // 4 ).cuda ().half ()
10331033 a = torch .rand (m , k ).cuda ().half () * mask
10341034 a_sparse = to_sparse_semi_structured (a )
1035- b = torch .rand (n , k ).cuda ().half (). t ()
1035+ b = torch .rand (k , n ).cuda ().half ()
10361036
10371037 with config .patch (
10381038 {
@@ -1335,7 +1335,7 @@ def test_cutlass_presets(
13351335
13361336 M , N , K = (128 , 128 , 16 )
13371337 A = torch .randn (M , K ).cuda ().half ()
1338- B = torch .randn (N , K ).cuda ().half (). t ()
1338+ B = torch .randn (K , N ).cuda ().half ()
13391339
13401340 def select_no_algorithm (* args , ** kwargs ):
13411341 raise NoValidChoicesError
0 commit comments