@@ -29,9 +29,13 @@ def wait(
29
29
scope : str = "gpu" ,
30
30
hasSubsequentMemAccess : bool = True ,
31
31
) -> None :
32
- """Wait until all entries of the signal_pad slice are equal to the signal value.
32
+ """
33
+ Wait for global memory barriers.
34
+
35
+ Spins on global memory barriers until the signal value is observed on all barriers.
36
+
33
37
Args:
34
- signal_pad: The signal pad tensor / stack tensor to wait on
38
+ signal_pad: Tensor of global memory barriers to wait on
35
39
index: Indices to index into the signal_pad tensor
36
40
signal: the value to wait for
37
41
update: Atomically update the signal_pad tensor with this value once the signal is observed. (default: None)
@@ -179,16 +183,22 @@ def signal(
179
183
scope : str = "gpu" ,
180
184
hasPreviousMemAccess : bool = True ,
181
185
) -> torch .Tensor :
182
- """Set the signal_pad slice to the signal value.
186
+ """
187
+ Set global memory barriers.
188
+
189
+ Sets global memory barriers to the specified value.
190
+ If wait_for is not None, it waits for the barriers to be cleared before setting.
191
+
183
192
Args:
184
- signal_pad: The signal pad tensor / stack tensor to signal
193
+ signal_pad: Tensor of global memory barriers to set
185
194
index: Indices to index into the signal_pad tensor
186
195
signal: the value to send
187
196
wait_for: The value to wait for before sending the signal.
188
197
scope: The scope of the lock (default: 'gpu')
189
198
hasPreviousMemAccess: Whether the signal is preceded by a memory access (default: True)
199
+
190
200
Returns:
191
- The old value of the signal_pad slice before the update.
201
+ The old value of the global memory barriers before the update.
192
202
"""
193
203
raise exc .NotInsideKernel
194
204
0 commit comments