Skip to content

Commit 26b45d8

Browse files
authored
Use 32-bit Philox consistently in randint4x even if offset is 64-bit (#6832)
Previously, passing a 64-bit offset to randint4x would cause Philox to use 64-bit arithmetic internally and return 64-bit integer tensors. Not only is this slow and register-hungry, it also contradicts the documentation of randint4x, which says it returns four 32-bit integer tensors. We fix this by using 32-bit Philox even when the offset is 64-bit; the low and high halves of the offset populate two of the four 32-bit words of the 128-bit counter.
1 parent 676227a commit 26b45d8

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

python/test/unit/language/test_random.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def test_randint(size, seed, device, dtype, const_seed):
125125
size = list(map(int, size.split(',')))
126126
torch_dtype = getattr(torch, dtype)
127127
numpy_dtype = getattr(np, f"u{dtype}")
128-
config = {'int32': PHILOX_32, 'int64': PHILOX_64}[dtype]
128+
config = PHILOX_32
129129

130130
@triton.jit
131131
def kernel(X, N, seed):

python/triton/language/random.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
5151
c1 = tl.to_tensor(c1)
5252
c2 = tl.to_tensor(c2)
5353
c3 = tl.to_tensor(c3)
54+
5455
if tl.constexpr(c0.dtype.primitive_bitwidth) == 32:
5556
int_dtype = tl.uint32
5657
seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
@@ -60,6 +61,7 @@ def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
6061
int_dtype = tl.uint64
6162
seed_hi = tl.full((1, ), 0, dtype=int_dtype)
6263
seed_lo = seed
64+
6365
c0 = c0.to(int_dtype, bitcast=True)
6466
c1 = c1.to(int_dtype, bitcast=True)
6567
c2 = c2.to(int_dtype, bitcast=True)
@@ -96,8 +98,16 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
9698
:param offsets: The offsets to generate random numbers for.
9799
"""
98100
# _0 = tl.zeros(offset.shape, offset.dtype)
99-
_0 = offset * 0
100-
return philox(seed, offset, _0, _0, _0, n_rounds)
101+
102+
offset_lo = offset.to(tl.uint32)
103+
_0 = offset_lo * 0
104+
105+
if tl.constexpr(offset.dtype.primitive_bitwidth) > 32:
106+
offset_hi = (offset >> 32).to(tl.uint32)
107+
else:
108+
offset_hi = _0
109+
110+
return philox(seed, offset_lo, offset_hi, _0, _0, n_rounds)
101111

102112

103113
# -------------------

0 commit comments

Comments
 (0)