-
Notifications
You must be signed in to change notification settings - Fork 35
Expand file tree
/
Copy pathmessage_passing_load_store.py
More file actions
executable file
·204 lines (163 loc) · 5.97 KB
/
message_passing_load_store.py
File metadata and controls
executable file
·204 lines (163 loc) · 5.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
import argparse
import torch
import triton
import triton.language as tl
import random
import iris
@triton.jit
def producer_kernel(
    source_buffer,  # tl.tensor: pointer to source data
    target_buffer,  # tl.tensor: pointer to target data
    flag,  # tl.tensor: pointer to flags (one per program)
    buffer_size,  # int32: total number of elements
    producer_rank: tl.constexpr,
    consumer_rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases_ptr: tl.tensor,  # tl.tensor: pointer to heap bases pointers
):
    # Copy one BLOCK_SIZE chunk of source_buffer (read from producer_rank's heap)
    # into target_buffer on consumer_rank's heap, then raise flag[pid] on
    # consumer_rank so the matching consumer program knows its chunk arrived.
    pid = tl.program_id(0)

    # Element indices owned by this program; the mask guards the tail block.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < buffer_size

    # Load the chunk from the producer's own copy of source_buffer.
    values = iris.load(source_buffer + offsets, producer_rank, producer_rank, heap_bases_ptr, mask=mask)

    # Store the chunk into the consumer's copy of target_buffer.
    iris.store(
        target_buffer + offsets,
        values,
        producer_rank,
        consumer_rank,
        heap_bases_ptr,
        mask=mask,
    )

    # Signal completion on the CONSUMER's flag array. A plain local tl.store here
    # would only update the producer's own flags, which the consumer never polls.
    # Release ordering at system scope is meant to make the data stores above
    # visible to the consumer before it can observe flag == 1.
    iris.atomic_cas(flag + pid, 0, 1, producer_rank, consumer_rank, heap_bases_ptr, sem="release", scope="sys")
@triton.jit
def consumer_kernel(
    buffer,  # tl.tensor: pointer to the consumer-local buffer the producer writes into
    flag,  # tl.tensor: sync flag per block
    buffer_size,  # int32: total number of elements
    consumer_rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases_ptr: tl.tensor,  # tl.tensor: pointer to heap bases pointers
):
    # Wait until flag[pid] reports that this program's chunk has been delivered,
    # then double the chunk in place in the consumer's own buffer.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < buffer_size

    # Spin until flag[pid] == 1. The compare-and-swap both observes the flag with
    # acquire ordering (so the buffer reads below are ordered after it) and
    # atomically resets it to 0 for a possible next round. Plain tl.load polling
    # provides no such ordering.
    while iris.atomic_cas(flag + pid, 1, 0, consumer_rank, consumer_rank, heap_bases_ptr, sem="acquire", scope="sys") == 0:
        pass

    # Read the delivered chunk from the consumer's own copy of the buffer.
    values = iris.load(buffer + offsets, consumer_rank, consumer_rank, heap_bases_ptr, mask=mask)

    # Example computation on the received data.
    values = values * 2

    # Write the result back to the same (local) buffer in place.
    iris.store(
        buffer + offsets,
        values,
        consumer_rank,
        consumer_rank,
        heap_bases_ptr,
        mask=mask,
    )
# Fixed seed for Python's and PyTorch's global RNGs so repeated runs draw the
# same pseudo-random values.
_SEED = 123
random.seed(_SEED)
torch.manual_seed(_SEED)
def torch_dtype_from_str(datatype: str) -> torch.dtype:
    """Map a command-line datatype name to the corresponding ``torch.dtype``.

    Args:
        datatype: One of ``"fp16"``, ``"fp32"``, ``"int8"`` or ``"bf16"``.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        SystemExit: With status 1 (after printing an error to stderr) when
            ``datatype`` is not a known name.
    """
    # Local import: sys.exit works in every interpreter mode, unlike the
    # interactive-only ``exit`` builtin injected by the ``site`` module.
    import sys

    dtype_map = {
        "fp16": torch.float16,
        "fp32": torch.float32,
        "int8": torch.int8,
        "bf16": torch.bfloat16,
    }
    try:
        return dtype_map[datatype]
    except KeyError:
        print(f"Unknown datatype: {datatype}", file=sys.stderr)
        sys.exit(1)
def parse_args():
    """Parse the example's command line from ``sys.argv``.

    Returns:
        A plain dict with keys ``datatype`` (str) and ``buffer_size``,
        ``block_size``, ``heap_size`` (int).
    """
    cli = argparse.ArgumentParser(
        description="Parse Message Passing configuration.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # (flag spellings, add_argument keyword options), in help-output order.
    option_table = (
        (
            ("-t", "--datatype"),
            dict(
                type=str,
                default="fp32",
                choices=["fp16", "fp32", "int8", "bf16"],
                help="Datatype of computation",
            ),
        ),
        (("-s", "--buffer_size"), dict(type=int, default=4096, help="Buffer Size")),
        (("-b", "--block_size"), dict(type=int, default=512, help="Block Size")),
        (("-p", "--heap_size"), dict(type=int, default=1 << 33, help="Iris heap size")),
    )
    for spellings, options in option_table:
        cli.add_argument(*spellings, **options)

    namespace = cli.parse_args()
    return vars(namespace)
def main():
    """Run the two-rank load/store message-passing example end to end.

    Rank 0 is the producer and rank 1 the consumer. The producer kernel copies
    ``source_buffer`` into the consumer's ``destination_buffer`` block by block.
    The consumer kernel waits on a per-block flag and doubles each block in
    place. Finally the consumer checks ``destination_buffer`` against
    ``source_buffer * 2``.

    Raises:
        ValueError: If the job was not launched with exactly two ranks.
    """
    args = parse_args()

    shmem = iris.iris(args["heap_size"])
    dtype = torch_dtype_from_str(args["datatype"])
    cur_rank = shmem.get_rank()
    world_size = shmem.get_num_ranks()

    # Fail fast, before allocating on the symmetric heap, if the topology is wrong.
    if world_size != 2:
        raise ValueError("This example requires exactly two processes.")
    producer_rank = 0
    consumer_rank = 1

    # Allocate source and destination buffers on the symmetric heap.
    source_buffer = shmem.zeros(args["buffer_size"], device="cuda", dtype=dtype)
    destination_buffer = shmem.randn(args["buffer_size"], device="cuda", dtype=dtype)

    n_elements = source_buffer.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    # One flag per launched program; matches the grid above because the kernels
    # receive args["block_size"] as BLOCK_SIZE.
    num_blocks = triton.cdiv(n_elements, args["block_size"])

    # Allocate flags on the symmetric heap.
    flags = shmem.zeros((num_blocks,), device="cuda", dtype=torch.int32)

    # The producer kernel stores into the consumer's copy of destination_buffer
    # (iris.store from producer_rank to consumer_rank). Without a barrier the
    # producer may run before the consumer process has finished its own
    # randn/zeros initialization, which would then overwrite delivered data.
    # NOTE(review): relies on shmem.barrier() also waiting for pending device
    # work, as the post-launch barrier below already assumes.
    shmem.barrier()

    heap_bases = shmem.get_heap_bases()
    if cur_rank == producer_rank:
        shmem.log(f"Rank {cur_rank} is sending data to rank {consumer_rank}.")
        producer_kernel[grid](
            source_buffer,
            destination_buffer,
            flags,
            n_elements,
            producer_rank,
            consumer_rank,
            args["block_size"],
            heap_bases,
        )
    else:
        shmem.log(f"Rank {cur_rank} is receiving data from rank {producer_rank}.")
        consumer_kernel[grid](
            destination_buffer, flags, n_elements, consumer_rank, args["block_size"], heap_bases
        )

    shmem.barrier()
    shmem.log(f"Rank {cur_rank} has finished sending/receiving data.")
    shmem.log("Validating output...")

    success = True
    if cur_rank == consumer_rank:
        expected = source_buffer * 2
        if not torch.allclose(destination_buffer, expected, atol=1):
            success = False
            max_diff = (destination_buffer - expected).abs().max().item()
            shmem.log(f"Max absolute difference: {max_diff}")
            # Report only the first mismatching element to keep the log short.
            mismatches = torch.nonzero(~torch.isclose(destination_buffer, expected, atol=1), as_tuple=False)
            if len(mismatches) > 0:
                idx = tuple(mismatches[0].tolist())
                computed_val = destination_buffer[idx]
                expected_val = expected[idx]
                shmem.log(f"Mismatch at index {idx}: C={computed_val}, expected={expected_val}")

    if success:
        shmem.log("Validation successful.")
    else:
        shmem.log("Validation failed.")

    shmem.barrier()


if __name__ == "__main__":
    main()