
Commit 9c78fb9

dmpots authored and pytorchmergebot committed
Fix assertion failure in gemm template lowering (pytorch#146353)
Summary: This commit fixes a crash in the gemm template lowering caused by hitting an [assert](https://github.com/pytorch/pytorch/blob/fd515e4f59bfa0ac9faa5185b7a02f3222c4cd08/torch/_inductor/codegen/common.py#L1181) that a buffer was previously removed. The assert fires because the first gemm lowering uses a local accumulation buffer, which adds the original buffer name to the `removed_buffers` set; the next gemm lowering then uses the global buffer for accumulation, but that buffer's name is already in `removed_buffers`. The fix is to give each output buffer name a unique suffix so that separate gemm lowerings no longer trigger the assert.

Differential Revision: D68814625

Pull Request resolved: pytorch#146353

Approved by: https://github.com/leslie-fang-intel, https://github.com/frost-intel, https://github.com/hl475
1 parent: 6cb2f73 · commit: 9c78fb9
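The failure mode is easy to reproduce in miniature. Below is a minimal standalone sketch, not the actual inductor code: `removed_buffers` matches the set named in the summary, while `lower_gemm` is a hypothetical stand-in for the template lowering.

# Minimal sketch of the bug: a registry asserts that a buffer marked as
# removed is never used again. Two gemm lowerings sharing one output
# buffer name trip the assert; unique suffixes avoid the collision.
removed_buffers: set[str] = set()

def lower_gemm(buf_name: str, use_local_acc: bool) -> None:
    if use_local_acc:
        # Local accumulation: the original output buffer goes unused,
        # so its name is recorded as removed.
        removed_buffers.add(buf_name)
    else:
        # Global accumulation: the output buffer is written directly,
        # so it must not have been marked removed earlier.
        assert buf_name not in removed_buffers, f"{buf_name} was removed"

# Before the fix, both lowerings shared the name "buf_out":
#   lower_gemm("buf_out", use_local_acc=True)
#   lower_gemm("buf_out", use_local_acc=False)  # AssertionError
# After the fix, each template instance gets a unique suffix:
lower_gemm("buf_out0", use_local_acc=True)
lower_gemm("buf_out1", use_local_acc=False)  # passes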

File tree

2 files changed: +60 -2 lines changed

test/inductor/test_cpu_select_algorithm.py
torch/_inductor/codegen/cpp_template.py

test/inductor/test_cpu_select_algorithm.py

Lines changed: 57 additions & 0 deletions

@@ -2299,6 +2299,63 @@ def forward(self, x, w):
         self.assertEqual(actual, expected, atol=atol, rtol=rtol)
         self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
 
+    @patches
+    @torch.no_grad
+    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
+    @set_num_threads(1)  # avoid k_slicing to make the test deterministic
+    @parametrize(
+        "out_features1",
+        (
+            8,
+            16,
+            24,
+            32,
+            48,
+        ),
+    )
+    @dtypes(torch.float)
+    def test_local_and_global_accumulator(self, out_features1, dtype):
+        batch_size = 256
+        in_features = 64
+        out_features = 129
+        in_features1 = 128
+        bias = True
+        try:
+            try:
+                from . import test_aot_inductor_utils
+            except ImportError:
+                import test_aot_inductor_utils
+        except Exception:
+            # skip this UT if import failed
+            return
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.linear = torch.nn.Linear(in_features, out_features, bias)
+                self.linear1 = torch.nn.Linear(in_features1, out_features1, bias)
+
+            def forward(self, x):
+                y = self.linear(x)
+                view = torch.ops.aten.view.default(y, [-1, in_features1])
+                return self.linear1(view)
+
+        counters.clear()
+        x = torch.randn(batch_size, in_features).to(dtype=dtype)
+        mod = M().to(dtype=dtype).eval()
+        with verify(dtype) as (atol, rtol), torch.no_grad():
+            expected = mod(
+                x,
+            )
+            actual = test_aot_inductor_utils.AOTIRunnerUtil.run(
+                "cpu",
+                mod,
+                (x,),
+            )
+            self.assertEqual(actual, expected, atol=atol, rtol=rtol)
+            self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)
+
 
 @dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
 class _DynamicShapesTestBase(BaseTestSelectAlgorithm):
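For reference, the shape choices in the new test make the aten view valid: the first linear produces a (256, 129) output, which reinterprets as (258, 128) for the second linear, since 256 × 129 = 33024 = 258 × 128. A quick eager-mode check of just the shapes, using out_features1 = 48 (one of the parametrized values):

import torch

x = torch.randn(256, 64)
y = torch.nn.Linear(64, 129, bias=True)(x)        # (256, 129)
view = torch.ops.aten.view.default(y, [-1, 128])  # (258, 128)
out = torch.nn.Linear(128, 48, bias=True)(view)   # (258, 48)
print(view.shape, out.shape)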

torch/_inductor/codegen/cpp_template.py

Lines changed: 3 additions & 2 deletions

@@ -34,8 +34,9 @@ def __init__(
     ) -> None:
         super().__init__(name)
         self.input_nodes = input_nodes
+        self.index = next(self.index_counter)
         self.output_node: Union[ir.Buffer, list[ir.Buffer]] = ir.Buffer(
-            name="buf_out", layout=layout
+            name=f"buf_out{self.index}", layout=layout
         )
         self.layout = layout
         self.num_threads = num_threads
@@ -75,7 +76,7 @@ def generate(self, **kwargs):
         # since in cpp kernel, we bind it to C long
         extra_args = tuple(ctypes.c_ulonglong(x) for x in extra_args)
 
-        kernel_hash_name = f"cpp_{self.name}_{next(self.index_counter)}"
+        kernel_hash_name = f"cpp_{self.name}_{self.index}"
 
         # Create the BenchmarkRequest for CPP
         bmreq = CppBenchmarkRequest(
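The effect of the patch: each template instance now draws its index once, in `__init__`, and reuses it for both the output buffer name and the kernel hash name, rather than consuming the counter again in `generate`. A minimal sketch of the pattern, assuming `index_counter` is a class-level `itertools.count` (implied by the `next()` calls); the `Template` class here is illustrative, not the real `CppTemplate`:

import itertools

class Template:
    # Shared by all instances, so each instance gets a distinct index.
    index_counter = itertools.count()

    def __init__(self, name: str) -> None:
        self.name = name
        self.index = next(self.index_counter)
        # Unique per instance: "buf_out0", "buf_out1", ...
        self.output_name = f"buf_out{self.index}"

    def kernel_hash_name(self) -> str:
        # Reuses the stored index instead of drawing a new one.
        return f"cpp_{self.name}_{self.index}"

a, b = Template("gemm"), Template("gemm")
assert a.output_name != b.output_name                # no buffer-name collision
assert a.kernel_hash_name() != b.kernel_hash_name()  # hash names stay distinct too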
