@@ -1378,6 +1378,127 @@ def _chebyshev_kernel_make_precompiler(x: torch.Tensor, w: torch.Tensor):
     return make_precompiler(_chebyshev_kernel_kernel)(x, w, out, out.stride(0), out.stride(1), w.stride(0), w.stride(1), x.stride(0), x.stride(1), B, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
         )
 
+    def test_loop_unroll1(self):
+        @helion.kernel()
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile]
+                for i in [1, 2, 3]:
+                    out[tile] += i
+            return out
+
+        x = torch.randn(4, device=DEVICE)
+        code, output = code_and_output(fn, (x,))
+        torch.testing.assert_close(output, x + 6)
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _fn_kernel(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    tl.store(out + indices_0 * out_stride_0, load, mask_0)
+    load_1 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    v_0 = 1.0
+    v_1 = load_1 + v_0
+    tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
+    load_2 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    v_2 = 2.0
+    v_3 = load_2 + v_2
+    tl.store(out + indices_0 * out_stride_0, v_3, mask_0)
+    load_3 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    v_4 = 3.0
+    v_5 = load_3 + v_4
+    tl.store(out + indices_0 * out_stride_0, v_5, mask_0)
+
+def fn(x: torch.Tensor):
+    out = torch.zeros_like(x)
+    _BLOCK_SIZE_0 = 4
+    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
+def _fn_make_precompiler(x: torch.Tensor):
+    out = torch.zeros_like(x)
+    _BLOCK_SIZE_0 = 4
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_fn_kernel)(x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
+        )
+
+    def test_loop_unroll2(self):
+        @helion.kernel()
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros_like(x)
+            a = 1
+            b = 2
+            c = 3
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile]
+                for i in (a, b, c):
+                    out[tile] += i
+            return out
+
+        x = torch.randn(4, device=DEVICE)
+        code, output = code_and_output(fn, (x,))
+        torch.testing.assert_close(output, x + 6)
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _fn_kernel(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    tl.store(out + indices_0 * out_stride_0, load, mask_0)
+    load_1 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    v_0 = 1.0
+    v_1 = load_1 + v_0
+    tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
+    load_2 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    v_2 = 2.0
+    v_3 = load_2 + v_2
+    tl.store(out + indices_0 * out_stride_0, v_3, mask_0)
+    load_3 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    v_4 = 3.0
+    v_5 = load_3 + v_4
+    tl.store(out + indices_0 * out_stride_0, v_5, mask_0)
+
+def fn(x: torch.Tensor):
+    out = torch.zeros_like(x)
+    a = 1
+    b = 2
+    c = 3
+    _BLOCK_SIZE_0 = 4
+    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
+def _fn_make_precompiler(x: torch.Tensor):
+    out = torch.zeros_like(x)
+    a = 1
+    b = 2
+    c = 3
+    _BLOCK_SIZE_0 = 4
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_fn_kernel)(x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
+        )
+
 
 if __name__ == "__main__":
     unittest.main()