Commit 39456ed

xuanzhang816 authored and pytorchmergebot committed
[PT2][memory] mutation size correctness (pytorch#157562)
Pull Request resolved: pytorch#157562
Approved by: https://github.com/yf225
1 parent a1dad2f commit 39456ed

3 files changed: +96 -18 lines changed

test/distributed/test_compute_comm_reordering.py

Lines changed: 1 addition & 1 deletion
@@ -179,9 +179,9 @@ def func(a):
             .check("extern_kernels.mm")
             .check("triton_poi_fused_relu")
             .check("torch.ops._c10d_functional.all_reduce_.default")
-            .check("extern_kernels.mm")
             .check("torch.ops._c10d_functional.wait_tensor.default")
             .check("extern_kernels.mm")
+            .check("extern_kernels.mm")
             .run(code)
         )
         out = compiled(inputs)

test/inductor/test_memory.py

Lines changed: 83 additions & 0 deletions
@@ -11,6 +11,15 @@
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


+try:
+    import triton
+    from triton import language as tl
+
+    TRITON_AVAILABLE = True
+except ImportError:
+    TRITON_AVAILABLE = False
+
+
 class Foo(torch.nn.Module):
     """
     The default compiled graph is
@@ -203,6 +212,80 @@ def reorder_with_only_dfs(
         outp = compiled_model(self.inputs)
         self.assertTrue(same(outp, outp_corr))

+    @mock.patch.object(config, "allow_buffer_reuse", False)
+    @unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available")
+    def test_mutation_size_propogation(self):
+        """
+        This tests correct size propagation in the case of mutations.
+        In this example, buf1 is a mutation of buf0; we should have:
+        * buf0: has size_alloc 2048 and size_free 0;
+        * buf1: has size_alloc 0 and size_free 2048.
+        This is because
+        - when buf1 is created, no additional memory is used; and
+        - the 2048 bytes of memory can only be released when buf1 is freed.
+        Similar arguments apply to buf2 and buf3, buf4 and buf5, etc.
+        """
+
+        # using a triton custom kernel to create a small example with mutations
+        @triton.jit
+        def convert_to_bf16_kernel(
+            input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(input_ptr + offsets, mask=mask)
+            x_bf16 = x.to(tl.bfloat16)
+            tl.store(output_ptr + offsets, x_bf16, mask=mask)
+
+        def convert_to_bf16(x):
+            output = torch.empty_like(x, dtype=torch.bfloat16)
+            n_elements = x.numel()
+            BLOCK_SIZE = 1024
+            grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
+            convert_to_bf16_kernel[grid](
+                x.flatten(), output.flatten(), n_elements, BLOCK_SIZE
+            )
+            return output.view(x.shape)
+
+        # create a custom function to record the buffer size information
+        buffer_info = {}
+        og_method = memory.assign_memory_planning_info_for_scheduler_buffers
+
+        def assign_memory_planning_info_for_scheduler_buffers_with_records(
+            nodes, name_to_buf
+        ):
+            og_method(nodes, name_to_buf)
+            for buf_name, buf in name_to_buf.items():
+                buffer_info[buf_name] = (
+                    buf.mpi_buffer.size_alloc,
+                    buf.mpi_buffer.size_free,
+                )
+
+        # test example and checks
+        def f(a, p):
+            for e in a:
+                e = convert_to_bf16(e)
+                p = p @ e
+            return p
+
+        a = [torch.randn(32, 32, device=GPU_TYPE) for _ in range(4)]
+        p = torch.ones(a[0].size(), dtype=torch.bfloat16, device=GPU_TYPE)
+
+        with mock.patch.object(
+            memory,
+            "assign_memory_planning_info_for_scheduler_buffers",
+            assign_memory_planning_info_for_scheduler_buffers_with_records,
+        ):
+            f_compiled = torch.compile(f)
+            f_compiled(a, p)
+            for buf_name in ["buf0", "buf2", "buf4", "buf6"]:
+                self.assertEqual(buffer_info[buf_name], (2048, 0))
+
+            for buf_name in ["buf1", "buf3", "buf5", "buf7"]:
+                self.assertEqual(buffer_info[buf_name], (0, 2048))
+
     @unittest.skipIf(
         not torch.cuda.is_available()
         or torch.cuda.get_device_properties().total_memory < int(1e10),
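
A note on the expected numbers in this test: each input is a 32 x 32 tensor converted to bfloat16, so one buffer holds 32 * 32 = 1024 elements at 2 bytes each, i.e. 2048 bytes. The standalone sketch below is my own illustration, not part of the commit; the helper name expected_mutation_sizes is made up. It reproduces that arithmetic and the convention the test asserts: the original buffer carries the allocation (size_alloc 2048, size_free 0) while its mutation carries the release (size_alloc 0, size_free 2048).

import torch

def expected_mutation_sizes(shape=(32, 32), dtype=torch.bfloat16):
    # hypothetical helper, for illustration only: bytes held by one buffer
    nbytes = torch.Size(shape).numel() * torch.tensor([], dtype=dtype).element_size()
    original = (nbytes, 0)   # buf0: allocates the memory, never frees it itself
    mutation = (0, nbytes)   # buf1: allocates nothing, releases the memory when freed
    return original, mutation

print(expected_mutation_sizes())  # ((2048, 0), (0, 2048))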

torch/_inductor/memory.py

Lines changed: 12 additions & 17 deletions
@@ -10,7 +10,7 @@
 from torch.utils._ordered_set import OrderedSet

 from .ir import MultiOutputLayout, NoneLayout
-from .utils import get_dtype_size, is_wait
+from .utils import get_dtype_size
 from .virtualized import V

@@ -147,23 +147,18 @@ def _compute_and_update_buf_size(
         sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False
     ) -> int:
         if isinstance(sched_buf.node.layout, NoneLayout):
-            _size = 0
-            # for a wait tensor op, its schedulerBuffer has NoneLayout layout. However,
-            # the schedulerBuffer is treated as a mutation of the collective output
-            # so it needs to inherit the size of the collectives
-            if (
-                sched_buf.defining_op
-                and is_wait(sched_buf.defining_op.node)
-                and sched_buf.get_mutations()
-            ):
+            # mutations should inherit the size of the mutated buffer
+            if sched_buf.get_mutations():
                 mutated_buf_name = sched_buf.get_mutations()[0]
-                _size = (
-                    sched_buf_to_size[mutated_buf_name][1]
-                    if mutated_buf_name in sched_buf_to_size
-                    else 0
-                )
-            sched_buf_to_size[sched_buf.get_name()] = (_size, _size)
-            return _size
+                if mutated_buf_name in sched_buf_to_size:
+                    (_size_alloc, _size_free) = sched_buf_to_size[mutated_buf_name]
+                else:
+                    (_size_alloc, _size_free) = (0, 0)
+                sched_buf_to_size[sched_buf.get_name()] = (0, _size_free)
+                sched_buf_to_size[mutated_buf_name] = (_size_alloc, 0)
+            else:
+                sched_buf_to_size[sched_buf.get_name()] = (0, 0)
+            return 0
         elif isinstance(sched_buf.node.layout, MultiOutputLayout):
             size_alloc = 0
             for user in sched_buf.users:
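
For readers skimming this hunk, here is a minimal, self-contained model of the new bookkeeping. It is my own sketch, not code from torch/_inductor/memory.py: SchedBuf and update_sizes are invented stand-ins for SchedulerBuffer and _compute_and_update_buf_size, and the starting entry for buf0 assumes a regular buffer is planned with size_alloc == size_free. It shows how a NoneLayout mutation buffer ends up with (0, size_free) while the buffer it mutates keeps (size_alloc, 0), so the bytes are counted once for allocation and once for release rather than twice.

from dataclasses import dataclass, field

@dataclass
class SchedBuf:
    # stand-in for Inductor's SchedulerBuffer; only what the sketch needs
    name: str
    mutations: list = field(default_factory=list)  # names of buffers this one mutates

def update_sizes(sched_buf, sched_buf_to_size):
    # mirrors the NoneLayout branch after this commit: the mutation takes over
    # size_free from the mutated buffer, which in turn gives up its size_free
    if sched_buf.mutations:
        mutated = sched_buf.mutations[0]
        size_alloc, size_free = sched_buf_to_size.get(mutated, (0, 0))
        sched_buf_to_size[sched_buf.name] = (0, size_free)
        sched_buf_to_size[mutated] = (size_alloc, 0)
    else:
        sched_buf_to_size[sched_buf.name] = (0, 0)
    return 0

# assumed starting state: buf0 is a regular 2048-byte buffer, buf1 mutates it
sizes = {"buf0": (2048, 2048)}
update_sizes(SchedBuf("buf1", ["buf0"]), sizes)
print(sizes)  # {'buf0': (2048, 0), 'buf1': (0, 2048)}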
