Commit d3efd73
Revert "[cutlass backend][BE][ez] Make matmul layouts be row x column (pytorch#156656)"
This reverts commit 84c588e. Reverted pytorch#156656 on behalf of https://github.com/henrylhtsang due to breaking fbcode A100 tests (see the comment on pytorch#156656).
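For context, the reverted PR made the tests exercise the row-major x column-major operand layout by building the second matmul operand as a transposed view; this revert restores plain row-major operands throughout. A minimal sketch of the difference between the two layouts (shapes borrowed from the tests below):

import torch

N, K = 2048, 25728

# Layout from the reverted PR: a column-major (K, N) view of a row-major (N, K) tensor.
b_col = torch.randn(N, K).cuda().half().t()
# Layout restored by this revert: a plain row-major (K, N) tensor.
b_row = torch.randn(K, N).cuda().half()

print(b_col.shape, b_col.stride(), b_col.is_contiguous())  # (25728, 2048) (1, 25728) False
print(b_row.shape, b_row.stride(), b_row.is_contiguous())  # (25728, 2048) (2048, 1) True

# Both produce the same values in an (M, K) @ (K, N) matmul; only the memory
# layout, and hence which CUTLASS GEMM layout the backend exercises, differs.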
1 parent 3684be0 commit d3efd73

File tree

1 file changed (+16 −16 lines)

test/inductor/test_cutlass_backend.py

Lines changed: 16 additions & 16 deletions
@@ -261,7 +261,7 @@ def test_cutlass_backend_subproc_mm(self):
         M, N, K = 4096, 2048, 25728

         a = torch.randn(M, K).cuda().half()
-        b = torch.randn(N, K).cuda().half().t()
+        b = torch.randn(K, N).cuda().half()

         with config.patch(
             {
@@ -289,7 +289,7 @@ def test_cutlass_backend_subproc_addmm(self, shape_combo):
         M, N, K = 4096, 2048, 25728

         a = torch.randn(M, K).cuda().half()
-        b = torch.randn(N, K).cuda().half().t()
+        b = torch.randn(K, N).cuda().half()

         x_shapes = [
             (M, N),
@@ -326,7 +326,7 @@ def test_cutlass_backend_subproc_bmm(self):
         B, M, N, K = 10, 4096, 2048, 25728

         a = torch.randn(B, M, K).cuda().half()
-        b = torch.randn(B, N, K).cuda().half().permute(0, 2, 1)
+        b = torch.randn(B, K, N).cuda().half()

         with config.patch(
             {
@@ -358,8 +358,8 @@ def forward(self, a, b, c):

         model = MyModel()
         a = torch.randn(128, 16).cuda().half()
-        b = torch.randn(128, 16).cuda().half().t()
-        c = torch.randn(512, 16).cuda().half().t()
+        b = torch.randn(16, 128).cuda().half()
+        c = torch.randn(16, 512).cuda().half()

         with config.patch(
             {
@@ -400,8 +400,8 @@ def forward(self, a, b, c):

         model = MyModel()
         a = torch.randn(128, 16).cuda().half()
-        b = torch.randn(128, 16).cuda().half().t()
-        c = torch.randn(512, 16).cuda().half().t()
+        b = torch.randn(16, 128).cuda().half()
+        c = torch.randn(16, 512).cuda().half()

         with config.patch(
             {
@@ -465,7 +465,7 @@ def forward(self, a, b):
         model = MyModel().cuda()

         inputs = [
-            (torch.randn(M, K).cuda().to(dtype), torch.randn(N, K).cuda().to(dtype).t())
+            (torch.randn(M, K).cuda().to(dtype), torch.randn(K, N).cuda().to(dtype))
             for (M, N, K) in shapes
         ]

@@ -633,7 +633,7 @@ def forward(self, x, a, b):
             (
                 torch.randn(x_shape(M, N)).cuda().to(dtype),
                 torch.randn(M, K).cuda().to(dtype),
-                torch.randn(N, K).cuda().to(dtype).t(),
+                torch.randn(K, N).cuda().to(dtype),
             )
             for (M, N, K) in shapes
         ]
@@ -744,7 +744,7 @@ def mm(a, b):
             return a @ b

         a = torch.randn(128, 16).cuda().half()
-        b = torch.randn(128, 16).cuda().half().t()
+        b = torch.randn(16, 128).cuda().half()

         with config.patch(
             {
@@ -770,7 +770,7 @@ def mm(a, b):
            ),
        ):
            a = torch.randn(M, K).cuda().half()
-           b = torch.randn(N, K).cuda().half().t()
+           b = torch.randn(K, N).cuda().half()
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
            Y = mm(a, b)
            # we need relaxed numerical limits due to the sheer size of the
@@ -935,7 +935,7 @@ def forward(self, x, w):
         }

         x = torch.randn(M, K).cuda().half()
-        w = torch.randn(N, K).cuda().half().t()
+        w = torch.randn(K, N).cuda().half()

         actual = AOTIRunnerUtil.run(
             model,
@@ -973,7 +973,7 @@ def forward(self, x, w):
         }

         x = torch.randn(M, K).cuda().half()
-        w = torch.randn(N, K).cuda().half().t()
+        w = torch.randn(K, N).cuda().half()

         actual = AOTIRunnerUtil.run(
             model,
@@ -1003,7 +1003,7 @@ def forward(self, x, w):
         M, N, K = 200, 5216, 10_432

         x = torch.randn(M, K).cuda().half()
-        w = torch.randn(N, K).cuda().half().t()
+        w = torch.randn(K, N).cuda().half()

         actual = AOTIRunnerUtil.run(
             model,
@@ -1032,7 +1032,7 @@ def mm(a, b):
         mask = torch.tensor([0, 0, 1, 1]).tile(m, k // 4).cuda().half()
         a = torch.rand(m, k).cuda().half() * mask
         a_sparse = to_sparse_semi_structured(a)
-        b = torch.rand(n, k).cuda().half().t()
+        b = torch.rand(k, n).cuda().half()

         with config.patch(
             {
@@ -1335,7 +1335,7 @@ def test_cutlass_presets(

         M, N, K = (128, 128, 16)
         A = torch.randn(M, K).cuda().half()
-        B = torch.randn(N, K).cuda().half().t()
+        B = torch.randn(K, N).cuda().half()

         def select_no_algorithm(*args, **kwargs):
             raise NoValidChoicesError
