Skip to content

Commit 5ee0bb9

Browse files
committed
Use the half-precision floating-point format as the data type for arguments
1 parent 0d220a5 commit 5ee0bb9

File tree

3 files changed: 14 additions (+14) and 11 deletions (-11).

add.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -59,8 +59,9 @@ def grid(meta):
59 59

60 60
torch.manual_seed(0)
61 61
size = 98432
62 -
lhs = torch.rand(size, device="cuda")
63 -
rhs = torch.rand(size, device="cuda")
62 +
dtype = torch.float16
63 +
lhs = torch.rand(size, dtype=dtype, device="cuda")
64 +
rhs = torch.rand(size, dtype=dtype, device="cuda")
64 65
ninetoothed_output = add(lhs, rhs)
65 66
torch_output = lhs + rhs
66 67
triton_output = triton_add(lhs, rhs)
@@ -92,8 +93,8 @@ def grid(meta):
92 93
)
93 94
)
94 95
def benchmark(size, provider):
95 -
lhs = torch.rand(size, device="cuda", dtype=torch.float32)
96 -
rhs = torch.rand(size, device="cuda", dtype=torch.float32)
96 +
lhs = torch.rand(size, device="cuda", dtype=torch.float16)
97 +
rhs = torch.rand(size, device="cuda", dtype=torch.float16)
97 98
quantiles = [0.5, 0.2, 0.8]
98 99

99 100
if provider == "ninetoothed":

conv2d.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,9 @@ def grid(meta):
200 200
torch.manual_seed(0)
201 201
n, c, h, w = 4, 3, 224, 224
202 202
k, _, r, s = 8, c, 3, 3
203 -
input = torch.randn(n, c, h, w, device="cuda")
204 -
filter = torch.randn(k, c, r, s, device="cuda")
203 +
dtype = torch.float16
204 +
input = torch.randn(n, c, h, w, dtype=dtype, device="cuda")
205 +
filter = torch.randn(k, c, r, s, dtype=dtype, device="cuda")
205 206
ninetoothed_output = conv2d(input, filter)
206 207
torch_output = F.conv2d(input, filter)
207 208
triton_output = triton_conv2d(input, filter)
@@ -233,8 +234,9 @@ def grid(meta):
233 234
def benchmark(h, w, provider):
234 235
n, c, _, _ = 64, 3, h, w
235 236
k, _, r, s = 64, c, 3, 3
236 -
input = torch.randn((n, c, h, w), device="cuda")
237 -
filter = torch.randn((k, c, r, s), device="cuda")
237 +
dtype = torch.float16
238 +
input = torch.randn((n, c, h, w), dtype=dtype, device="cuda")
239 +
filter = torch.randn((k, c, r, s), dtype=dtype, device="cuda")
238 240

239 241
if provider == "ninetoothed":
240 242
ms = triton.testing.do_bench(lambda: conv2d(input, filter))

softmax.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -72,14 +72,14 @@ def triton_softmax(input):
72 72

73 73

74 74
torch.manual_seed(0)
75 -
input = torch.randn(1823, 781, device="cuda")
75 +
input = torch.randn(1823, 781, dtype=torch.float16, device="cuda")
76 76
ninetoothed_output = softmax(input)
77 77
torch_output = torch.softmax(input, axis=-1)
78 78
triton_output = triton_softmax(input)
79 79
print(ninetoothed_output)
80 80
print(torch_output)
81 81
print(triton_output)
82 -
if torch.allclose(ninetoothed_output, torch_output):
82 +
if torch.allclose(ninetoothed_output, torch_output, atol=1e-5):
83 83
print("✅ NineToothed and PyTorch match.")
84 84
else:
85 85
print("❌ NineToothed and PyTorch differ.")
@@ -103,7 +103,7 @@ def triton_softmax(input):
103 103
)
104 104
)
105 105
def benchmark(m, n, provider):
106 -
input = torch.randn(m, n, device="cuda", dtype=torch.float32)
106 +
input = torch.randn(m, n, device="cuda", dtype=torch.float16)
107 107
stream = torch.cuda.Stream()
108 108
torch.cuda.set_stream(stream)
109 109

0 commit comments

Comments
 (0)