Skip to content

Commit aeb4d4f

Browse files
authored
[TESTS] Reducing warning messages in test_core.py (#6869)
1 parent 9e92724 commit aeb4d4f

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

python/test/unit/language/test_core.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,8 @@ def kernel(Z, X, SIZE: tl.constexpr):
350350
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
351351
# inputs
352352
x = numpy_random(SIZE, dtype_str=dtype_x)
353-
if 'log' in expr:
353+
# avoid log/sqrt of negative numbers
354+
if 'log' in expr or 'sqrt' in expr:
354355
x = np.abs(x) + 0.01
355356
# reference result
356357
z_ref = eval(expr if numpy_expr is None else numpy_expr)
@@ -1270,7 +1271,7 @@ def kernel():
12701271
a = tl.arange(0, 64).reshape(2, 4, 8).trans((2, 1, 0))
12711272
tl.static_assert(a.shape == [tl.constexpr(8), tl.constexpr(4), tl.constexpr(2)])
12721273

1273-
a = tl.arange(0, 64).view(2, 4, 8)
1274+
a = tl.reshape(tl.arange(0, 64), 2, 4, 8, can_reorder=True)
12741275
tl.static_assert(a.shape == [tl.constexpr(2), tl.constexpr(4), tl.constexpr(8)])
12751276

12761277
kernel[(1, )]()
@@ -1543,6 +1544,8 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device):
15431544
if is_interpreter():
15441545
if dtype_x_str == 'float16' or dtype_x_str == 'bfloat16':
15451546
pytest.skip("Only test atomic bfloat16/float16 ops on GPU")
1547+
if "uint" in dtype_x_str and mode in ["min_neg", "all_neg"]:
1548+
pytest.skip("uint cannot be negative")
15461549

15471550
n_programs = 5
15481551

@@ -1745,7 +1748,7 @@ def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.const
17451748
xoffset = tl.program_id(0) * XBLOCK
17461749
x_idx = xoffset + tl.arange(0, XBLOCK)[:]
17471750
mask = x_idx < shape0 * shape1
1748-
mask = mask and (x_idx % mask_step != 0)
1751+
mask = mask & (x_idx % mask_step != 0)
17491752
idx_base = shape1 * (x_idx // shape1)
17501753
idx_offset = tl.load(idx_ptr + x_idx, mask)
17511754
in_elem = tl.load(in_ptr + x_idx, mask)
@@ -2758,7 +2761,7 @@ def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.const
27582761

27592762
elif op == 'cummax':
27602763
# NumPy does not have cummax
2761-
z = z.astype(np.int64)
2764+
z = np.empty_like(x, dtype=np.int64)
27622765
z_ref = torch.cummax(torch.from_numpy(x_in.copy()), axis=axis).indices.numpy()
27632766
if reverse:
27642767
z_ref = x_in.shape[axis] - np.flip(z_ref, axis) - 1
@@ -7515,6 +7518,7 @@ def g(y, dtype):
75157518

75167519

75177520
@pytest.mark.interpreter
7521+
@pytest.mark.filterwarnings("ignore:If conditional called with multidimensional Tensor*")
75187522
def test_unsplat(device):
75197523

75207524
@triton.jit

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,9 @@ struct FpToFpOpConversion
431431
// mul{.rnd}.bf16 and mul{.rnd}.bf16x2 requires sm_90 or higher.
432432
{{F8E5M2TyID, BF16TyID, undefRounding},
433433
Fp8E5M2_to_Bf16(computeCapability >= 90)},
434+
// cvt with '.bf16.f16' requires .target sm_90 or higher
434435
{{F8E4M3TyID, BF16TyID, undefRounding},
435-
Fp8E4M3Nv_to_Bf16(computeCapability >= 89)},
436+
Fp8E4M3Nv_to_Bf16(computeCapability >= 90)},
436437
// BF16 -> F8
437438
{{BF16TyID, F8E5M2TyID, RoundingMode::RTNE},
438439
Bf16_to_Fp8E5M2(computeCapability >= 89)},

0 commit comments

Comments (0)