 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
 
 
+try:
+    import triton
+    from triton import language as tl
+
+    TRITON_AVAILABLE = True
+except ImportError:
+    TRITON_AVAILABLE = False
+
+
 class Foo(torch.nn.Module):
     """
     The default compiled graph is
@@ -203,6 +212,88 @@ def reorder_with_only_dfs(
         outp = compiled_model(self.inputs)
         self.assertTrue(same(outp, outp_corr))
 
+    @mock.patch.object(config, "allow_buffer_reuse", False)
+    @unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available")
+    def test_mutation_size_propagation(self):
+        """
+        This tests correct size propagation in the case of mutations.
+        In this example, buf1 is a mutation of buf0; we should have:
+        * buf0 has size_alloc 2048 and size_free 0;
+        * buf1 has size_alloc 0 and size_free 2048.
+        This is because
+        - when buf1 is created, no additional memory is used; and
+        - the 2048 bytes of memory can only be released when buf1 is freed.
+        The same reasoning applies to the pairs buf2/buf3, buf4/buf5, etc.
+        """
+
+        # use a triton custom kernel to create a small example with mutations
+        @triton.jit
+        def convert_to_bf16_kernel(
+            input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
+        ):
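+            # each program instance handles one BLOCK_SIZE chunk; the mask
+            # guards the out-of-range tail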
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(input_ptr + offsets, mask=mask)
+            x_bf16 = x.to(tl.bfloat16)
+            tl.store(output_ptr + offsets, x_bf16, mask=mask)
+
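+        # host-side wrapper: allocate the bf16 output and launch the kernel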
+        def convert_to_bf16(x):
+            output = torch.empty_like(x, dtype=torch.bfloat16)
+            n_elements = x.numel()
+            BLOCK_SIZE = 1024
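+            # 1D launch grid: ceil(n_elements / BLOCK_SIZE) program instances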
+            grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
+            convert_to_bf16_kernel[grid](
+                x.flatten(), output.flatten(), n_elements, BLOCK_SIZE
+            )
+            return output.view(x.shape)
+
+        # create a custom function to record the buffer size information
+        buffer_info = {}
+        og_method = memory.assign_memory_planning_info_for_scheduler_buffers
+
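+        # call the original pass, then record (size_alloc, size_free) for
+        # every scheduler buffer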
+        def assign_memory_planning_info_for_scheduler_buffers_with_records(
+            nodes, name_to_buf
+        ):
+            og_method(nodes, name_to_buf)
+            for buf_name, buf in name_to_buf.items():
+                buffer_info[buf_name] = (
+                    buf.mpi_buffer.size_alloc,
+                    buf.mpi_buffer.size_free,
+                )
+
+        # test example and checks
+        def f(a, p):
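+            # each iteration produces one alloc/free buffer pair checked below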
+            for e in a:
+                e = convert_to_bf16(e)
+                p = p @ e
+            return p
+
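+        # each 32x32 bf16 buffer occupies 32 * 32 * 2 = 2048 bytes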
+        a = [torch.randn(32, 32, device=GPU_TYPE) for _ in range(4)]
+        p = torch.ones(a[0].size(), dtype=torch.bfloat16, device=GPU_TYPE)
+
+        with mock.patch.object(
+            memory,
+            "assign_memory_planning_info_for_scheduler_buffers",
+            assign_memory_planning_info_for_scheduler_buffers_with_records,
+        ):
+            f_compiled = torch.compile(f)
+            f_compiled(a, p)
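+            # even-numbered buffers carry the allocation; their odd-numbered
+            # mutation counterparts carry the deallocation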
+            for buf_name in ["buf0", "buf2", "buf4", "buf6"]:
+                self.assertEqual(buffer_info[buf_name], (2048, 0))
+
+            for buf_name in ["buf1", "buf3", "buf5", "buf7"]:
+                self.assertEqual(buffer_info[buf_name], (0, 2048))
+
     @unittest.skipIf(
         not torch.cuda.is_available()
         or torch.cuda.get_device_properties().total_memory < int(1e10),