
Commit 902ee15

Sync changes from NV compiler.py (#5122)
1. Renamed `threads_per_warp` to `warp_size`.
2. Changed some passes under `gluon_to_ttgir` from `ttgpuir` to `gluon`.

Signed-off-by: Whitney Tsang <[email protected]>
Co-authored-by: Ilya Enkovich <[email protected]>
1 parent 6704582 commit 902ee15
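For illustration, a minimal before/after sketch of the user-facing rename on an XPU target; the kernel, tensors, and values below are hypothetical, and only the keyword change is taken from this commit.

# Before this commit (hypothetical kernel and arguments)
my_kernel[(1, )](x, y, BLOCK=128, num_warps=4, threads_per_warp=16)

# After this commit: same launch, renamed keyword
my_kernel[(1, )](x, y, BLOCK=128, num_warps=4, warp_size=16)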

3 files changed: 13 additions, 14 deletions

python/test/unit/language/test_core.py

Lines changed: 2 additions & 3 deletions
@@ -2393,8 +2393,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
     # triton result
     z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str), device=device, dst_type=z_tri_dtype_str)
     if is_xpu():
-        kernel[(1, )](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas, num_warps=num_warps,
-                      threads_per_warp=threads_per_warp)
+        kernel[(1, )](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas, num_warps=num_warps, warp_size=threads_per_warp)
     else:
         kernel[(1, )](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas)
     z_tri = to_numpy(z_tri)
@@ -2527,7 +2526,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
     kern_kwargs = {}
     if is_xpu():
         kern_kwargs['num_warps'] = num_warps
-        kern_kwargs['threads_per_warp'] = threads_per_warp
+        kern_kwargs['warp_size'] = threads_per_warp
     if axis is not None and axis >= len(shape):
         with pytest.raises(triton.TritonError):
             kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis,

python/tutorials/02-fused-softmax.py

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@ def allocated_slm_size(size_smem):
 
     # pre-compile kernel to get register usage and compute thread occupancy.
     kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, num_warps=num_warps,
-                                   threads_per_warp=WARP_SIZE, BLOCK_SIZE=BLOCK_SIZE, grid=(1, ))
+                                   warp_size=WARP_SIZE, BLOCK_SIZE=BLOCK_SIZE, grid=(1, ))
     kernel._init_handles()
     size_smem = kernel.metadata.shared
     num_programs = occupancy(num_warps, size_smem)

third_party/intel/backend/compiler.py

Lines changed: 10 additions & 10 deletions
@@ -21,7 +21,7 @@ class XPUOptions:
     num_ctas: int = 1
     num_stages: int = 2
     cluster_dims: tuple = (1, 1, 1)
-    threads_per_warp: int = 32
+    warp_size: int = 32
     optimize_epilogue: bool = False
     enable_fp_fusion: bool = True
     launch_cooperative_grid: bool = False
@@ -177,10 +177,10 @@ def load_dialects(self, ctx):
 
     @staticmethod
     def validate_options(opt, properties):
-        # Check threads_per_warp and num_threads are within limits.
-        if opt.threads_per_warp not in properties['sub_group_sizes']:
+        # Check warp_size and num_threads are within limits.
+        if opt.warp_size not in properties['sub_group_sizes']:
             raise ValueError(
-                f"threads_per_warp={opt.threads_per_warp} is unsupported for the target (supported values are {properties['sub_group_sizes']})"
+                f"warp_size={opt.warp_size} is unsupported for the target (supported values are {properties['sub_group_sizes']})"
             )
         if opt.num_warps > properties['max_num_sub_groups']:
             raise ValueError(
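As a side note, here is a minimal sketch of what the renamed check amounts to; the `opt` and `properties` values below are hypothetical stand-ins rather than anything produced by the real driver query, and the snippet simply mirrors the condition above outside the backend class.

from types import SimpleNamespace

# Hypothetical stand-ins for the compiler options and device properties.
opt = SimpleNamespace(warp_size=64, num_warps=8)
properties = {"sub_group_sizes": [16, 32], "max_num_sub_groups": 64}

try:
    # warp_size must be one of the device's supported sub-group sizes.
    if opt.warp_size not in properties["sub_group_sizes"]:
        raise ValueError(f"warp_size={opt.warp_size} is unsupported for the target "
                         f"(supported values are {properties['sub_group_sizes']})")
except ValueError as e:
    print(e)  # warp_size=64 is unsupported for the target (supported values are [16, 32])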
@@ -197,7 +197,7 @@ def annotate_module(mod, properties, opt, target_arch):
         module_opts.support_sg_2d_block = properties["has_subgroup_2d_block_io"]
         module_opts.support_dpas = properties["has_subgroup_matrix_multiply_accumulate"]
         module_opts.support_bf16_conversion = properties["has_bfloat16_conversions"]
-        module_opts.threads_per_warp = opt.threads_per_warp
+        module_opts.threads_per_warp = opt.warp_size
         module_opts.target_arch = target_arch
         intel.passes.ttgpuir.add_triton_annotate_module(pm, module_opts)
         pm.run(mod)
@@ -241,8 +241,8 @@ def make_ttgir(mod, metadata, opt, properties):
         # Annotate module with information required by subsequent transformations.
         XPUBackend.annotate_module(mod, properties, opt, "spir64")
 
-        # Overwrite the threads_per_warp option with the module annotation.
-        opt.threads_per_warp = intel.get_threads_per_warp(mod)
+        # Overwrite the warp_size option with the module annotation.
+        opt.warp_size = intel.get_threads_per_warp(mod)
         XPUBackend.validate_options(opt, properties)
 
         if (properties["has_subgroup_2d_block_io"] and properties["has_subgroup_matrix_multiply_accumulate"]
@@ -251,7 +251,7 @@ def make_ttgir(mod, metadata, opt, properties):
 
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
-        passes.ttir.add_convert_to_ttgpuir(pm, "xpu", opt.num_warps, opt.threads_per_warp, opt.num_ctas)
+        passes.ttir.add_convert_to_ttgpuir(pm, "xpu", opt.num_warps, opt.warp_size, opt.num_ctas)
         # optimize TTGIR
         intel.passes.ttgpuir.add_coalesce(pm)
         intel.passes.ttgpuir.add_remove_layout_conversions(pm)
@@ -296,11 +296,11 @@ def gluon_to_ttgir(self, src, metadata, options):
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
 
-        passes.ttgpuir.add_inliner(pm)
+        passes.gluon.add_inliner(pm)
         passes.gluon.add_resolve_auto_encodings(pm)
         passes.common.add_sccp(pm)
         passes.ttir.add_loop_aware_cse(pm)
-        passes.ttgpuir.add_canonicalizer(pm)
+        passes.gluon.add_canonicalizer(pm)
         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
 
         pm.run(mod)
