diff --git a/iris/iris.py b/iris/iris.py index 172ffb5a..76da934e 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1710,7 +1710,7 @@ def __translate(ptr, from_rank, to_rank, heap_bases): @triton.jit -def load(pointer, to_rank, from_rank, heap_bases, mask=None): +def load(pointer, to_rank, from_rank, heap_bases, mask=None, cache_modifier=None, volatile=False): """ Loads a value from the specified rank's memory location. @@ -1719,12 +1719,28 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None): data from the target memory location. If the `from_rank` and `to_rank` are the same, this function performs a local load operation. + The `cache_modifier` parameter controls instruction-level cache behavior + by setting the appropriate scope (`SC0`, `SC1`) and non-temporal (`NT`) bits + in the global load instruction. These affect cache usage across the CU, + L2, and last-level caches. + Args: pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. to_rank (int): The rank ID to which the pointer will be translated. Must be the current rank where the pointer is local. from_rank (int): The rank ID from which to read the data. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None. + cache_modifier (str, optional): Controls cache behavior of the load. + + Supported values: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + Ensures global coherence by invalidating stale GPU cache lines. + + volatile (bool, optional): If True, disables compiler optimizations that + could reorder or eliminate the load. Returns: Block: The loaded value from the target memory location. @@ -1739,12 +1755,12 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None): >>> return data """ translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases) - result = tl.load(translated_ptr, mask=mask) + result = tl.load(translated_ptr, mask=mask, cache_modifier=cache_modifier, volatile=volatile) return result @triton.jit -def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): +def store(pointer, value, from_rank, to_rank, heap_bases, mask=None, cache_modifier=None): """ Writes data to the specified rank's memory location. @@ -1753,6 +1769,11 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): the provided data to the target memory location. If the `from_rank` and `to_rank` are the same, this function performs a local store operation. + The `cache_modifier` parameter controls instruction-level cache behavior + by setting the appropriate scope (`SC0`, `SC1`) and non-temporal (`NT`) bits + in the global store instruction. These affect cache usage across the CU (L1), + L2, and last-level cache (LLC), following the CDNA ISA. + Args: pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the `from_rank`'s address space that will be translated to the `to_rank`'s address space. Must be the current rank where the pointer is local. value (Block): The tensor of elements to be stored. @@ -1760,6 +1781,13 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): to_rank (int): The rank ID to which the data will be written. heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. + cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. Returns: None @@ -1774,11 +1802,21 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None): >>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases) """ translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases) - tl.store(translated_ptr, value, mask=mask) + tl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier) @triton.jit -def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None): +def copy( + src_ptr, + dst_ptr, + from_rank, + to_rank, + cur_rank, + heap_bases, + mask=None, + load_cache_modifier=None, + store_cache_modifier=None, +): """ Copies data from the specified rank's memory into the destination rank's memory. This function performs the transfer by translating `src_ptr` from the `from_rank`'s address @@ -1796,6 +1834,19 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None): heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from the translated src_ptr[idx] and do not store to dst_ptr[idx]. Defaults to None. + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. + Returns: None @@ -1824,12 +1875,14 @@ def copy(src_ptr, dst_ptr, from_rank, to_rank, cur_rank, heap_bases, mask=None): translated_src = tl.cast(from_base_byte + src_offset, src_ptr.dtype) translated_dst = tl.cast(to_base_byte + dst_offset, src_ptr.dtype) - data = tl.load(translated_src, mask=mask) - tl.store(translated_dst, data, mask=mask) + data = tl.load(translated_src, mask=mask, cache_modifier=load_cache_modifier) + tl.store(translated_dst, data, mask=mask, cache_modifier=store_cache_modifier) @triton.jit -def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): +def get( + from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, load_cache_modifier=None, store_cache_modifier=None +): """ Copies data from the specified rank's memory to the current rank's local memory. @@ -1846,6 +1899,19 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. + Returns: None @@ -1858,13 +1924,15 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): """ translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases) - data = tl.load(translated_from_ptr, mask=mask) + data = tl.load(translated_from_ptr, mask=mask, cache_modifier=load_cache_modifier) - tl.store(to_ptr, data, mask=mask) + tl.store(to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) @triton.jit -def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): +def put( + from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, load_cache_modifier=None, store_cache_modifier=None +): """ Copies data from the current rank's local memory to the specified rank's memory. This function performs a memory write operation by loading data from the current @@ -1880,6 +1948,19 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks. mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: + - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. + - ".ca": Cache at all levels (CU, L2, LLC) with LRU policy. + - ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted. + - ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted. + + store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: + - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. + - ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later. + - ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU. + - ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC. + - ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU. + Returns: None @@ -1892,9 +1973,9 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None): """ translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases) - data = tl.load(from_ptr, mask=mask) + data = tl.load(from_ptr, mask=mask, cache_modifier=load_cache_modifier) - tl.store(translated_to_ptr, data, mask=mask) + tl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier) @triton.jit diff --git a/tests/unittests/test_copy_cache_modifiers.py b/tests/unittests/test_copy_cache_modifiers.py new file mode 100644 index 00000000..b7c278ea --- /dev/null +++ b/tests/unittests/test_copy_cache_modifiers.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris +from itertools import product + + +@triton.jit +def copy_kernel_local_read_remote_write( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + """Copy from local memory to remote memory (local read, remote write)""" + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + # Copy from current rank to other ranks + for target_rank in range(num_ranks): + src_data = data + BLOCK_SIZE * cur_rank + dest_data = results + BLOCK_SIZE * cur_rank + if load_cache_modifier is None and store_cache_modifier is None: + iris.copy(src_data + offsets, dest_data + offsets, cur_rank, target_rank, cur_rank, heap_bases, mask=mask) + elif load_cache_modifier is None: + iris.copy( + src_data + offsets, + dest_data + offsets, + cur_rank, + target_rank, + cur_rank, + heap_bases, + mask=mask, + store_cache_modifier=store_cache_modifier, + ) + elif store_cache_modifier is None: + iris.copy( + src_data + offsets, + dest_data + offsets, + cur_rank, + target_rank, + cur_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + ) + else: + iris.copy( + src_data + offsets, + dest_data + offsets, + cur_rank, + target_rank, + cur_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +@triton.jit +def copy_kernel_remote_read_local_write( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + """Copy from remote memory to local memory (remote read, local write)""" + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + # Copy from other ranks to current rank + for source_rank in range(num_ranks): + src_data = data + BLOCK_SIZE * source_rank + dest_data = results + BLOCK_SIZE * source_rank + if load_cache_modifier is None and store_cache_modifier is None: + iris.copy(src_data + offsets, dest_data + offsets, source_rank, cur_rank, cur_rank, heap_bases, mask=mask) + elif load_cache_modifier is None: + iris.copy( + src_data + offsets, + dest_data + offsets, + source_rank, + cur_rank, + cur_rank, + heap_bases, + mask=mask, + store_cache_modifier=store_cache_modifier, + ) + elif store_cache_modifier is None: + iris.copy( + src_data + offsets, + dest_data + offsets, + source_rank, + cur_rank, + cur_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + ) + else: + iris.copy( + src_data + offsets, + dest_data + offsets, + source_rank, + cur_rank, + cur_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +# Define cache modifiers for load and store operations +LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"] +# Remote stores (cross-GPU IPC) cannot use cache modifier bits +# Only default (None or empty string) works - cache bits break coherency +STORE_CACHE_MODIFIERS_REMOTE_WRITE = [None, ""] +# For testing remote reads (which work with all load modifiers), +# we can use all store modifiers since the store is local +LOAD_CACHE_MODIFIERS_REMOTE_READ = [None, "", ".ca", ".cg", ".cv"] +STORE_CACHE_MODIFIERS_LOCAL_WRITE = [None, "", ".wb", ".cg", ".cs", ".wt"] + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS_REMOTE_WRITE)) +) +def test_copy_local_read_remote_write(load_cache_modifier, store_cache_modifier): + """Test copy: local read → remote write + + Direction: from_rank=cur_rank (local), to_rank=other (remote) + - Load: from LOCAL memory (all cache modifiers should work) + - Store: to REMOTE memory (only None/"" work, cache bits break coherency) + + This tests that load cache modifiers work for local reads. + """ + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + grid = lambda meta: (1,) + copy_kernel_local_read_remote_write[grid]( + data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier + ) + + shmem.barrier() + + # Verify results - each rank copies its data to all other ranks + for rank_id in range(num_ranks): + expected_value = (rank_id + num_ranks) * (rank_id + 1) + assert torch.allclose( + results[rank_id], torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device) + ), ( + f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", + list(product(LOAD_CACHE_MODIFIERS_REMOTE_READ, STORE_CACHE_MODIFIERS_LOCAL_WRITE)), +) +def test_copy_remote_read_local_write(load_cache_modifier, store_cache_modifier): + """Test copy: remote read → local write + + Direction: from_rank=other (remote), to_rank=cur_rank (local) + - Load: from REMOTE memory (test if cache modifiers work for remote reads) + - Store: to LOCAL memory (all cache modifiers should work) + + This tests whether load cache modifiers work for remote reads. + """ + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + data = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + base = cur_rank + num_ranks + for i in range(num_ranks): + data[i, :] = base * (i + 1) + + results = shmem.zeros((num_ranks, BLOCK_SIZE), dtype=torch.float32) + grid = lambda meta: (1,) + copy_kernel_remote_read_local_write[grid]( + data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier + ) + + shmem.barrier() + + # Verify results - each rank pulls data from all ranks + for rank_id in range(num_ranks): + expected_value = (rank_id + num_ranks) * (rank_id + 1) + assert torch.allclose( + results[rank_id], torch.full((BLOCK_SIZE,), expected_value, dtype=torch.float32, device=results.device) + ), ( + f"Mismatch at rank {cur_rank}, slot {rank_id} with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) diff --git a/tests/unittests/test_get_cache_modifiers.py b/tests/unittests/test_get_cache_modifiers.py new file mode 100644 index 00000000..58cb9d48 --- /dev/null +++ b/tests/unittests/test_get_cache_modifiers.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris +from itertools import product + + +@triton.jit +def get_kernel( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + acc = tl.zeros([BLOCK_SIZE], dtype=data.type.element_ty) + + # Loop over all ranks, get the stored data with cache modifiers + # We test default values set by the function when parameters are None + for target_rank in range(num_ranks): + if load_cache_modifier is None and store_cache_modifier is None: + iris.get(data + offsets, results + offsets, cur_rank, target_rank, heap_bases, mask=mask) + elif load_cache_modifier is None: + iris.get( + data + offsets, + results + offsets, + cur_rank, + target_rank, + heap_bases, + mask=mask, + store_cache_modifier=store_cache_modifier, + ) + elif store_cache_modifier is None: + iris.get( + data + offsets, + results + offsets, + cur_rank, + target_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + ) + else: + iris.get( + data + offsets, + results + offsets, + cur_rank, + target_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + acc += tl.load(results + offsets, mask=mask) + + # Store the accumulated value back to the output + tl.store(results + offsets, acc, mask=mask) + + +# Define cache modifiers for load and store operations +LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"] +STORE_CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"] + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_get_cache_modifiers(load_cache_modifier, store_cache_modifier): + """Test get (copy from other rank) with various cache modifiers.""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + results = shmem.zeros_like(data) + + shmem.barrier() + + grid = lambda meta: (1,) + get_kernel[grid]( + data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier + ) + shmem.barrier() + + # Verify the result - should get data from all ranks (including self) + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * num_ranks + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"GET test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + print("Expected:", expected) + print("Actual:", results) + raise diff --git a/tests/unittests/test_load_cache_modifiers.py b/tests/unittests/test_load_cache_modifiers.py new file mode 100644 index 00000000..5c147300 --- /dev/null +++ b/tests/unittests/test_load_cache_modifiers.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris +from itertools import product + + +@triton.jit +def kernel( + data, + results, + source_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + cache_modifier: tl.constexpr, + volatile: tl.constexpr, +): + pid = tl.program_id(0) + + partner = int((source_rank + num_ranks // 2) % num_ranks) + # Compute start index of this block + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # Guard for out-of-bounds accesses + mask = offsets < BLOCK_SIZE + + if cache_modifier is None: + result = iris.load(data + offsets, source_rank, partner, heap_bases, mask=mask, volatile=volatile) + else: + result = iris.load( + data + offsets, + source_rank, + partner, + heap_bases, + mask=mask, + cache_modifier=cache_modifier, + volatile=volatile, + ) + + tl.store(results + offsets, result, mask=mask) + + +# Define cache modifiers and volatile options +CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"] +VOLATILE_OPTIONS = [False, True] + + +@pytest.mark.parametrize("cache_modifier,volatile", list(product(CACHE_MODIFIERS, VOLATILE_OPTIONS))) +def test_load_cache_modifiers(cache_modifier, volatile): + """Test load with various cache modifiers and volatile settings.""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + source_rank = shmem.get_rank() + partner = int((source_rank + num_ranks // 2) % num_ranks) + + BLOCK_SIZE = 16 + data = shmem.full((BLOCK_SIZE,), source_rank, dtype=torch.float32) + results = shmem.zeros_like(data) + + shmem.barrier() + + grid = lambda meta: (1,) + kernel[grid](data, results, source_rank, num_ranks, BLOCK_SIZE, heap_bases, cache_modifier, volatile) + shmem.barrier() + + # Verify the result + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") * partner + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(e) + print("Expected:", expected) + print("Actual:", results) + raise diff --git a/tests/unittests/test_put_cache_modifiers.py b/tests/unittests/test_put_cache_modifiers.py new file mode 100644 index 00000000..01b48037 --- /dev/null +++ b/tests/unittests/test_put_cache_modifiers.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris +from itertools import product + + +@triton.jit +def put_kernel( + data, + results, + cur_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + load_cache_modifier: tl.constexpr, + store_cache_modifier: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < BLOCK_SIZE + + # Put data to all ranks with cache modifiers + # We test default values set by the function when parameters are None + for target_rank in range(num_ranks): + if load_cache_modifier is None and store_cache_modifier is None: + iris.put(data + offsets, results + offsets, cur_rank, target_rank, heap_bases, mask=mask) + elif load_cache_modifier is None: + iris.put( + data + offsets, + results + offsets, + cur_rank, + target_rank, + heap_bases, + mask=mask, + store_cache_modifier=store_cache_modifier, + ) + elif store_cache_modifier is None: + iris.put( + data + offsets, + results + offsets, + cur_rank, + target_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + ) + else: + iris.put( + data + offsets, + results + offsets, + cur_rank, + target_rank, + heap_bases, + mask=mask, + load_cache_modifier=load_cache_modifier, + store_cache_modifier=store_cache_modifier, + ) + + +# Define cache modifiers for load and store operations +LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"] +STORE_CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"] + + +@pytest.mark.parametrize( + "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS)) +) +def test_put_cache_modifiers(load_cache_modifier, store_cache_modifier): + """Test put (copy to other rank) with various cache modifiers.""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + cur_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + data = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + results = shmem.zeros_like(data) + + shmem.barrier() + + grid = lambda meta: (1,) + put_kernel[grid]( + data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier + ) + shmem.barrier() + + # Verify the result - should have the data that was put + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print( + f"PUT test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}" + ) + print(e) + print("Expected:", expected) + print("Actual:", results) + raise diff --git a/tests/unittests/test_store_cache_modifiers.py b/tests/unittests/test_store_cache_modifiers.py new file mode 100644 index 00000000..892a09fb --- /dev/null +++ b/tests/unittests/test_store_cache_modifiers.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import torch +import triton +import triton.language as tl +import pytest +import iris + + +@triton.jit +def kernel( + data, + results, + destination_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + heap_bases: tl.tensor, + cache_modifier: tl.constexpr, +): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + mask = offsets < BLOCK_SIZE + + # Load the data from src for this block + value = tl.load(data + offsets, mask=mask) + + # Store data to all ranks with the specified cache modifier + for dst_rank in range(num_ranks): + if cache_modifier is None: + iris.store(results + offsets, value, destination_rank, dst_rank, heap_bases, mask=mask) + else: + iris.store( + results + offsets, + value, + destination_rank, + dst_rank, + heap_bases, + mask=mask, + cache_modifier=cache_modifier, + ) + + +# Define cache modifiers for store operations +# Based on the provided cache modifier descriptions +CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"] + + +@pytest.mark.parametrize("cache_modifier", CACHE_MODIFIERS) +def test_store_cache_modifiers(cache_modifier): + """Test store with various cache modifiers.""" + shmem = iris.iris(1 << 20) + num_ranks = shmem.get_num_ranks() + heap_bases = shmem.get_heap_bases() + destination_rank = shmem.get_rank() + + BLOCK_SIZE = 16 + src = shmem.ones(BLOCK_SIZE, dtype=torch.float32) + results = shmem.zeros_like(src) + + shmem.barrier() + + grid = lambda meta: (1,) + kernel[grid](src, results, destination_rank, num_ranks, BLOCK_SIZE, heap_bases, cache_modifier) + shmem.barrier() + + # Verify the result + expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device="cuda") + + try: + torch.testing.assert_close(results, expected, rtol=0, atol=0) + except AssertionError as e: + print(e) + print("Expected:", expected) + print("Actual:", results) + raise