
Commit 676227a

[BENCH] Address and/or warnings in triton kernels (#6841)
1 parent 8fe51e2 commit 676227a

4 files changed, 17 additions and 10 deletions

python/triton/compiler/code_generator.py

Lines changed: 9 additions & 2 deletions
@@ -1300,8 +1300,15 @@ def visit_BoolOp(self, node: ast.BoolOp):
                 # expression so we do not append it to nontrivial_values.
             else:
                 if value.type.is_block():
-                    warnings.warn(
-                        "Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead"
+                    lineno = getattr(node, "lineno", None)
+                    if lineno is not None:
+                        lineno += self.begin_line
+                    warnings.warn_explicit(
+                        "Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead",
+                        category=UserWarning,
+                        filename=self.file_name,
+                        lineno=lineno,
+                        source=ast.unparse(node),
                     )
                 # not a constexpr so we must append it:
                 nontrivial_values.append(value)
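This change swaps warnings.warn for warnings.warn_explicit so the deprecation message is attributed to the offending line of the user's kernel (offset by self.begin_line) rather than to the compiler's own frame. A minimal sketch of the call, with a hypothetical filename and line number standing in for the values the code generator computes:

import warnings

# warn() reports the frame that calls it (here that would be the compiler);
# warn_explicit() reports exactly the file name and line number it is given.
warnings.warn_explicit(
    "Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead",
    category=UserWarning,
    filename="my_kernel.py",  # hypothetical: the user's kernel source file
    lineno=42,                # hypothetical: the line containing the 'and'/'or'
)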

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 1 addition & 1 deletion
@@ -321,7 +321,7 @@ def _compute_writeback_idx(
     is_src_active = (src_idxs != -1).to(tl.int32)
     has_one_active = tl.sum(is_src_active, axis=1) == 1

-    need_finalize_scatter = mask_m and not has_one_active
+    need_finalize_scatter = mask_m & (~has_one_active)
     finalize_scatter_count = tl.sum(need_finalize_scatter.to(tl.int32))
     if finalize_scatter_count == 0:
         return
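In Triton, Python's `and`/`not` on block tensors reduce the operands to a single truth value and now emit the deprecation warning handled above; `&` and `~` apply elementwise to the boolean masks, which is what the kernel intends. A standalone sketch of the same pattern (kernel name and arguments are illustrative, not from this file):

import triton
import triton.language as tl

@triton.jit
def _count_unfinished(flags_ptr, count_ptr, n, BLOCK: tl.constexpr):
    # Hypothetical kernel: count in-bounds elements whose flag is not set.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    in_bounds = offs < n
    flagged = tl.load(flags_ptr + offs, mask=in_bounds, other=1) != 0
    # Elementwise AND / NOT; `in_bounds and not flagged` is the deprecated
    # scalar form that this commit replaces throughout the kernels.
    need = in_bounds & (~flagged)
    tl.store(count_ptr + tl.program_id(0), tl.sum(need.to(tl.int32)))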

python/triton_kernels/triton_kernels/numerics_details/mxfp.py

Lines changed: 6 additions & 6 deletions
@@ -155,13 +155,13 @@ def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr,

     mask_src_quant = start_src_quant + offs_src_quant < quant_dim
     mask_n = start_out + offs_outer < outer_dim
-    full_mask_src = mask_src_quant and mask_n
+    full_mask_src = mask_src_quant & mask_n

     mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR)
-    full_mask_mxt = mask_mxt_quant and mask_n
+    full_mask_mxt = mask_mxt_quant & mask_n

     scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, 32)
-    full_scale_mask = scale_mask_k and mask_n
+    full_scale_mask = scale_mask_k & mask_n

     src_tensor_offsets = offs_src_quant * stride_src_quant + offs_outer * stride_src_outer
     mx_scale_offsets = offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer
@@ -219,13 +219,13 @@ def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr,

     mask_outer = start_out + offs_outer < outer_dim
     mask_out_quant = start_out_quant + offs_out_quant < quant_dim
-    full_mask_out = mask_out_quant and mask_outer
+    full_mask_out = mask_out_quant & mask_outer

     mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR)
-    full_mask_src = mask_src_quant and mask_outer
+    full_mask_src = mask_src_quant & mask_outer

     mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, 32)
-    full_scale_mask = mask_scale and mask_outer
+    full_scale_mask = mask_scale & mask_outer

     tensor_offsets = offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer
     scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer

python/triton_kernels/triton_kernels/swiglu_details/_swiglu.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def _swiglu(Out, OutExpectedScale, OutActualScale, OutChecksumScale, A, AScale,
     if pid_n * BLOCK_N + BLOCK_N <= N:
         a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.)
     else:
-        packed_mask = mask_m[:, None] and packed_mask_n[None, :]
+        packed_mask = mask_m[:, None] & packed_mask_n[None, :]
         a_packed = tl.load(A + packed_offs, mask=packed_mask, other=0.)
     a_gelu, a_linear = tl.split(tl.reshape(a_packed, (BLOCK_M, BLOCK_N, 2)))
     # a gelu
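The same substitution also covers broadcast masks: a (BLOCK_M, 1) row mask combined via `&` with a (1, N) column mask yields the 2-D load mask, which Python's `and` cannot express. A standalone sketch of that pattern (names and shapes are illustrative, not taken from _swiglu):

import triton
import triton.language as tl

@triton.jit
def _masked_tile_copy(src_ptr, dst_ptr, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Hypothetical kernel: copy one (BLOCK_M, BLOCK_N) tile with boundary masking.
    offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_m = offs_m < M
    mask_n = offs_n < N
    # Broadcasted elementwise AND builds the 2-D mask; `and` would not broadcast.
    mask = mask_m[:, None] & mask_n[None, :]
    offs = offs_m[:, None] * N + offs_n[None, :]
    vals = tl.load(src_ptr + offs, mask=mask, other=0.)
    tl.store(dst_ptr + offs, vals, mask=mask)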
