
Commit cd4527d

Merge commit '7d3bf12eecc81346528f0072aedf69a8a5af41e5'
2 parents: 71a87ff + 7d3bf12

37 files changed: +567 -497 lines changed

lib/Dialect/Gluon/Transforms/ResolveAutoEncodings.cpp

Lines changed: 6 additions & 3 deletions

@@ -121,9 +121,12 @@ LogicalResult inferAutoLayouts(FuncOp func) {
     } else {
       auto srcEncoding = inferSrcEncoding(definingOp, enc);
       if (srcEncoding) {
-        if (failed(updateEncoding(
-                llvm::to_vector_of<Value>(definingOp->getOperands()),
-                srcEncoding)))
+        llvm::SmallVector<Value> tensorOperands;
+        for (auto operand : definingOp->getOperands())
+          if (isa<RankedTensorType>(operand.getType()))
+            tensorOperands.push_back(operand);
+
+        if (failed(updateEncoding(tensorOperands, srcEncoding)))
           return failure();
       }
     }
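
The operand filtering above matters because source-encoding inference can reach ops whose operands are not all ranked tensors. For instance, the Gluon full builtin lowers to a splat of a scalar (see the _semantic.py context further down), so propagating an inferred encoding back to that op's operands must skip the scalar. A minimal illustrative sketch, not part of this commit, with import paths assumed from the file layout in this diff:

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl

@gluon.jit
def splat_reduce_kernel(n):
    # ttgl.full of a runtime scalar becomes a splat whose only operand is an
    # i32 scalar, not a RankedTensorType; the updated pass skips it when
    # pushing the inferred encoding onto the defining op's operands.
    x = ttgl.full([16, 8], n, ttgl.int32, layout=ttgl.AutoLayout())
    ttgl.sum(x, axis=1)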

python/test/gluon/test_frontend.py

Lines changed: 25 additions & 0 deletions

@@ -1233,3 +1233,28 @@ def test_auto_layout():
     z = x + y
     # CHECK: (tensor<16x8xi32, #gluon.auto_encoding>) -> tensor<16xi32, #gluon.auto_encoding
     ttgl.sum(z, axis=1)
+
+    # CHECK: tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #gluon.auto_encoding>
+    ttgl.arange(0, 32)
+
+
+@filecheck_test
+@gluon.jit
+def test_auto_layout_broadcast():
+    # CHECK: [[BLOCKED:#.*]] = #ttg.blocked
+    # CHECK: [[X:%.*]] = arith.constant dense<1> : tensor<16x1xi32, #gluon.auto_encoding>
+    # CHECK: [[Y:%.*]] = arith.constant dense<2> : tensor<1x16xi32, [[BLOCKED]]>
+    x = ttgl.full([16, 1], 1, ttgl.int32, layout=ttgl.AutoLayout())
+    y = ttgl.full([1, 16], 2, ttgl.int32, layout=ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0]))
+
+    # CHECK: [[XCVT:%.*]] = ttg.convert_layout [[X]] : tensor<16x1xi32, #gluon.auto_encoding> -> tensor<16x1xi32, [[BLOCKED]]>
+    # CHECK: [[XBCAST:%.*]] = tt.broadcast [[XCVT]]
+    # CHECK: [[YBCAST:%.*]] = tt.broadcast [[Y]]
+    # CHECK: arith.addi [[XBCAST]], [[YBCAST]] : tensor<16x16xi32, [[BLOCKED]]>
+    _ = x + y
+
+    # CHECK: [[XCVT2:%.*]] = ttg.convert_layout [[X]] : tensor<16x1xi32, #gluon.auto_encoding> -> tensor<16x1xi32, [[BLOCKED]]>
+    # CHECK: [[YBCAST2:%.*]] = tt.broadcast [[Y]]
+    # CHECK: [[XBCAST2:%.*]] = tt.broadcast [[XCVT2]]
+    # CHECK: arith.muli [[YBCAST2]], [[XBCAST2]] : tensor<16x16xi32, [[BLOCKED]]>
+    _ = y * x
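
For contrast, a hedged sketch of the case that still fails inside a @gluon.jit kernel like the ones above: only AutoLayout is converted implicitly, so mixing two different concrete layouts keeps raising the layout-mismatch error shown in _semantic.py below (the second BlockedLayout's parameters are illustrative, not from this commit):

a = ttgl.full([16, 1], 1, ttgl.int32, layout=ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0]))
b = ttgl.full([1, 16], 2, ttgl.int32, layout=ttgl.BlockedLayout([1, 1], [32, 1], [1, 4], [0, 1]))
_ = a + b  # raises "Layout mismatch in broadcast: ...", surfaced as a compilation error at trace time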

python/test/unit/language/test_compile_errors.py

Lines changed: 9 additions & 5 deletions

@@ -7,7 +7,7 @@
 import triton.language as tl
 from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure
 import traceback
-from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3, is_xpu
+from triton._internal_testing import is_cuda, is_hip, is_hip_cdna4, is_xpu


 def format_exception(type, value, tb):
@@ -364,9 +364,9 @@ def test_fp8_support(fresh_triton_cache, dtype):
         if cc >= (8, 9):
             supported_dtypes.append(tl.float8e4nv)
     elif is_hip():
-        supported_dtypes.append(tl.float8e4nv)
-        if is_hip_cdna3():
-            supported_dtypes += [tl.float8e4b8, tl.float8e5b16]
+        supported_dtypes += [tl.float8e4nv, tl.float8e4b8, tl.float8e5b16]
+        if is_hip_cdna4():
+            warning_dtypes += [tl.float8e4b8, tl.float8e5b16]
     elif is_xpu():
         supported_dtypes += [tl.float8e4b15, tl.float8e4nv]

@@ -376,7 +376,11 @@ def dtype_kernel(dtype: tl.constexpr):
        tl.dot(a, a)

    if dtype in warning_dtypes:
-        ctx = pytest.warns(UserWarning, match=r"the use of fp8e4b15 is deprecated on Hopper and later architectures")
+        if is_cuda():
+            ctx = pytest.warns(UserWarning,
+                               match=r"the use of fp8e4b15 is deprecated on Hopper and later architectures")
+        elif is_hip_cdna4():
+            ctx = pytest.warns(UserWarning, match=r"AMD gfx942 specific and not supported on gfx950")
    elif dtype in supported_dtypes:
        ctx = contextlib.nullcontext()
    else:

python/test/unit/language/test_conversions.py

Lines changed: 7 additions & 5 deletions

@@ -7,7 +7,7 @@
 import triton
 import triton.language as tl

-from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3, is_hip_cdna4, is_xpu
+from triton._internal_testing import is_cuda, is_hip, is_hip_cdna2, is_hip_cdna3, is_hip_cdna4, is_xpu


 def matching_int(dtype):
@@ -297,6 +297,7 @@ def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bia
     ('float8e4nv', 'float32'),

     ('float8e4b8', 'float32'),
+    ('float8e4b8', 'bfloat16'),
     ('float8e4b8', 'float16'),

     ('float8e5b16', 'float32'),
@@ -316,12 +317,13 @@ def test_typeconvert_upcast(src_dtype, dst_dtype, device):
     elif is_hip():
         if (src_dtype == 'float8e4nv' and not (is_hip_cdna3() or is_hip_cdna4())):
             pytest.skip(f"upcasting {src_dtype} to {dst_dtype} not supported in this architecture")
-        if (src_dtype in ('float8e4b15') or
-                (src_dtype in ('float8e4b8', 'float8e5b16') and not is_hip_cdna3())):
+        if src_dtype == 'float8e4b15':
             # If the dtype should error out in the given device, we assert that and return
             with pytest.raises(triton.CompilationError, match="not supported in this architecture"):
                 launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device)
             return
+        if src_dtype in ('float8e4b8', 'float8e5b16') and is_hip_cdna2():
+            pytest.skip(f"{src_dtype} is not supported on AMDGPU CDNA2")
     elif is_xpu():
         if (src_dtype in ('float8e4b8', 'float8e5b16')):
             # If the dtype should error out in the given device, we assert that and return
@@ -379,8 +381,8 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
         pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU CDNA3")

     if is_hip():
-        if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and not is_hip_cdna3():
-            pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU CDNA3")
+        if dst_dtype in ('float8e4b8', 'float8e5b16') and is_hip_cdna2():
+            pytest.skip(f"{dst_dtype} is not supported on AMDGPU CDNA2")

     if is_xpu():
         if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne':

python/test/unit/language/test_core.py

Lines changed: 2 additions & 2 deletions

@@ -6139,9 +6139,9 @@ def kernel(Out):


 def test_globaltimer(device):
-    if is_hip_cdna2():
-        pytest.skip("test_globaltimer is flaky on gfx90a")
     check_cuda_or_hip(device)
+    if is_hip():
+        pytest.skip("test_globaltimer is flaky on AMD GPUs")

     @triton.jit
     def kernel(Out1, Out2, func: tl.constexpr):

python/triton/_filecheck.py

Lines changed: 7 additions & 4 deletions

@@ -42,8 +42,9 @@ def run_filecheck(name, module_str, check_template):
         temp.write(check_template)

     try:
-        subprocess.check_output([filecheck_path, temp_expected, "--input-file", temp_module],
-                                stderr=subprocess.STDOUT)
+        subprocess.check_output(
+            [filecheck_path, temp_expected, "--input-file", temp_module, "--dump-input-context=50"],
+            stderr=subprocess.STDOUT)
     except subprocess.CalledProcessError as error:
         decoded = error.output.decode('unicode_escape')
         raise ValueError(decoded)
@@ -60,8 +61,10 @@ def run_parser(kernel_fn):
     ir.load_dialects(context)
     stub_backend.load_dialects(context)

-    extra_options = src.parse_options()
-    options = stub_backend.parse_options(dict(**extra_options))
+    options = dict(sanitize_overflow=False)
+    options.update(src.parse_options())
+
+    options = stub_backend.parse_options(options)
    codegen_fns = stub_backend.get_codegen_implementation(options)
    module_map = stub_backend.get_module_map()
    module = src.make_ir(options, codegen_fns, module_map, context)

python/triton/experimental/gluon/language/_core.py

Lines changed: 3 additions & 2 deletions

@@ -43,6 +43,7 @@
 )

 _IMPORT_FROM_TRITON: List[str] = [
+    "broadcast",
     "expand_dims",
     "inline_asm_elementwise",
     "join",
@@ -341,14 +342,14 @@ def _keep_alive(self, _semantic: GluonSemantic = None) -> None:


 @builtin
-def arange(start, end, layout, _semantic=None):
+def arange(start, end, layout=None, _semantic=None):
     """
     Generate a sequence tensor with values in [start, end) using a specified layout.

     Args:
         start (int): Inclusive start of the sequence.
         end (int): Exclusive end of the sequence.
-        layout (DistributedLayout): The layout of the output tensor.
+        layout (DistributedLayout): The layout of the output tensor. Defaults to AutoLayout.

     Returns:
         tensor: A 1D tensor containing sequential values.
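
A short usage sketch of the relaxed signature; the explicit 1D BlockedLayout parameters are illustrative assumptions, not taken from this commit:

xs = ttgl.arange(0, 32)  # layout omitted: defaults to AutoLayout, i.e. #gluon.auto_encoding in the IR
ys = ttgl.arange(0, 32, layout=ttgl.BlockedLayout([1], [32], [4], [0]))  # explicit layout still works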

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 11 additions & 2 deletions

@@ -112,7 +112,14 @@ def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy:
         lhs_shape = lhs_ty.get_block_shapes()
         rhs_shape = rhs_ty.get_block_shapes()
         ret_shape = self._broadcast_shapes(lhs_shape, rhs_shape)
-        if lhs_ty.layout != rhs_ty.layout:
+
+        is_lhs_auto = isinstance(lhs_ty.layout, AutoLayout)
+        is_rhs_auto = isinstance(rhs_ty.layout, AutoLayout)
+        if is_lhs_auto and not is_rhs_auto:
+            lhs = self.convert_layout(lhs, rhs_ty.layout)
+        elif is_rhs_auto and not is_lhs_auto:
+            rhs = self.convert_layout(rhs, lhs_ty.layout)
+        elif lhs_ty.layout != rhs_ty.layout:
             raise ValueError(f"Layout mismatch in broadcast: {lhs_ty.layout} vs {rhs_ty.layout}")

         lhs = self.broadcast_impl_shape(lhs, ret_shape)
@@ -121,6 +128,8 @@ def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy:

     def arange(self, start, end, layout):
         shape = [end - start]
+        if layout is None:
+            layout = AutoLayout()
         ret_ty = ttgl.distributed_type(ttgl.int32, shape, layout)
         return super().arange(start, end, ret_ty=ret_ty)

@@ -138,7 +147,7 @@ def full(self, shape, value, dtype, layout):
         scalar = self.make_scalar(value, dtype)
         return self.splat(scalar, shape, layout)

-    def convert_layout(self, value, layout, assert_trivial):
+    def convert_layout(self, value, layout, assert_trivial=False):
         ty = value.type
         _check(isinstance(ty, ttgl.distributed_type),
                lambda: f"expected convert_layout input to be a distributed_type but got: {ty!r}")
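
In user terms, the new broadcast path is roughly equivalent to converting the AutoLayout side explicitly before the elementwise op. A hedged sketch, assuming convert_layout is exposed on ttgl as a thin wrapper over the semantic method shown above:

blocked = ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0])
x = ttgl.full([16, 1], 1, ttgl.int32, layout=ttgl.AutoLayout())
y = ttgl.full([1, 16], 2, ttgl.int32, layout=blocked)
x_cvt = ttgl.convert_layout(x, blocked)  # what broadcast_impl_value now inserts for the auto side
z = x_cvt + y                            # both operands share the blocked layout; result is 16x16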

python/triton/language/semantic.py

Lines changed: 12 additions & 0 deletions

@@ -1489,6 +1489,18 @@ def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Opti
             lhs = self.cast(lhs, tl.float16)
             rhs = self.cast(rhs, tl.float16)

+        uses_fp8e4b8 = lhs.dtype.is_fp8e4b8() or rhs.dtype.is_fp8e4b8()
+        uses_fp8e5b16 = lhs.dtype.is_fp8e5b16() or rhs.dtype.is_fp8e5b16()
+        if uses_fp8e4b8 or uses_fp8e5b16:
+            type_name = "fp8e4b8" if uses_fp8e4b8 else "fp8e5b16"
+            if type_name in self.builder.options.deprecated_fp8_dot_operand_dtypes:
+                arch = self.builder.options.arch
+                warnings.warn(
+                    f"{type_name} is AMD gfx942 specific and not supported on {arch} so it's upcasted to fp16 and can cause significant slow down. "
+                    f"Please use OCP fp8 variants on {arch} for performance")
+                lhs = self.cast(lhs, tl.float16)
+                rhs = self.cast(rhs, tl.float16)
+
         if input_precision is None:
             input_precision = self.builder.options.default_dot_input_precision
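
A hedged sketch of how the new warning surfaces to users, reusing the match string from test_compile_errors.py above; the kernel body and the warmup-based compilation call are illustrative assumptions, not part of this commit:

import pytest
import triton
import triton.language as tl

@triton.jit
def fp8_dot_kernel(dtype: tl.constexpr):
    a = tl.full((64, 64), 0.0, dtype)  # hypothetical operand setup
    tl.dot(a, a)

def check_deprecated_fp8_warning():
    # On an AMD target where fp8e4b8 dot operands are listed as deprecated,
    # compiling the kernel should emit the UserWarning added in dot() above.
    with pytest.warns(UserWarning, match=r"AMD gfx942 specific and not supported"):
        fp8_dot_kernel.warmup(tl.float8e4b8, grid=(1,))  # warmup compiles without launching (assumed API)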

test/Conversion/amd/async-ops-alias-scopes.mlir

Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
    // COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}} {alias_scopes = [[[$ASYNC_COPY_SCOPE]]]
    // Check that store for 'other' has alias information set
    // COMMON: llvm.store {{.*}} {alias_scopes = [[[$LOCAL_LOAD_SCOPE]]], noalias_scopes = [[[$ASYNC_COPY_SCOPE]]]
-    %65 = amdgpu.buffer_load_to_local %arg1[%arg2] mask=%mask other=%other into %arg3 {OpIdx = #amdgpu.OpIdx<1>} : <f32>[tensor<8x64xi32, #blocked>] tensor<8x64xf32, #blocked> -> <8x64xf32, #shared, #smem, mutable>
+    %65 = amdgpu.buffer_load_to_local %arg1[%arg2] mask=%mask other=%other into %arg3 : <f32>[tensor<8x64xi32, #blocked>] tensor<8x64xf32, #blocked> -> <8x64xf32, #shared, #smem, mutable>

    // COMMON: llvm.return
    tt.return
