|
4 | 4 | # See https://llvm.org/LICENSE.txt for license information. |
5 | 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | 6 |
|
| 7 | +import copy |
7 | 8 | import sympy |
8 | 9 | import functools |
9 | 10 | from typing import Any, Optional, Dict |
|
34 | 35 | from ...compiler.vector_codegen import ( |
35 | 36 | cast_kernel_buffer, |
36 | 37 | cast_py_literal, |
| 38 | + cast_py_value, |
37 | 39 | cast_vector, |
38 | 40 | ) |
39 | 41 |
|
40 | | -from ...ops.wave_ops import get_custom, read, write, CustomOp |
| 42 | +from ...ops.wave_ops import ( |
| 43 | + CustomOp, |
| 44 | + gather_to_lds, |
| 45 | + get_custom, |
| 46 | + read, |
| 47 | + write, |
| 48 | +) |
41 | 49 |
|
42 | 50 | from ..utils.general_utils import get_fastest_index, infer_dim |
| 51 | +from ..utils.mapping_utils import transform_index_on_mapping |
43 | 52 | from ..utils.symbol_utils import safe_subs, subs_idxc |
44 | 53 |
|
45 | 54 | from ..._support.indexing import IndexingContext, IndexExpr, IndexSequence, IndexSymbol |
|
48 | 57 |
|
49 | 58 | from .emitter import ( |
50 | 59 | WaveEmitter, |
51 | | - handle_op, |
52 | 60 | add_emitter_subs, |
53 | 61 | gen_sympy_index, |
54 | 62 | get_constant_attr, |
| 63 | + get_type_or_element_type, |
| 64 | + handle_op, |
55 | 65 | ) |
56 | 66 |
|
57 | 67 |
|
@@ -883,3 +893,82 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): |
883 | 893 | mask, |
884 | 894 | offsets_vec, |
885 | 895 | ) |
| 896 | + |
| 897 | + |
@handle_op(gather_to_lds)
def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node):
    """Lower a `gather_to_lds` graph node to an `amdgpu_d.gather_to_lds` op.

    `node.args` must unpack into exactly eight values, in this order:
    source, destination, source index, destination index, element type,
    elements per thread, source mapping and destination mapping.

    Args:
        emitter: Emitter used to materialize values and start indices.
        node: The fx node being lowered.

    Raises:
        ValidationError: If `node.args` does not have eight entries, if
            either operand does not lower to a memref-typed value, or if
            the source and destination element types differ.
    """
    try:
        (
            src_node,
            dst_node,
            src_idx,
            dst_idx,
            dtype_arg,
            elements_per_thread,
            src_mapping,
            dst_mapping,
        ) = node.args
    except ValueError as e:
        raise ValidationError("Malformed arguments") from e

    element_type = IrType.parse(dtype_arg.dtype.ir_type_asm())

    # Symbolic shapes are taken from the graph-level operands, before they
    # are converted into IR values.
    src_symbolic_shape = _get_symbolic_shape(src_node)
    dst_symbolic_shape = _get_symbolic_shape(dst_node)

    src_proxy = cast_py_value(emitter, src_node)
    dst_proxy = cast_py_value(emitter, dst_node)
    src_memref = src_proxy.ir_value
    dst_memref = dst_proxy.ir_value
    src_type = src_memref.type
    dst_type = dst_memref.type
    src_data_type = get_type_or_element_type(src_type)
    dst_data_type = get_type_or_element_type(dst_type)

    # Both operands have to be memrefs.
    if not MemRefType.isinstance(src_type) or not MemRefType.isinstance(dst_type):
        op = get_custom(node)
        raise ValidationError(
            f"Expected src and dst to be of Memref type for\n"
            f"{op}\nGot\n"
            f"src: {src_type}\n"
            f"dst: {dst_type}\n"
        )

    # ...and they must agree on the element type.
    if src_data_type != dst_data_type:
        op = get_custom(node)
        raise ValidationError(
            f"Expected src and dst to have same data type for\n"
            f"{op}\nGot\n"
            f"src: {src_data_type} vs dst: {dst_data_type}\n"
        )

    # Apply the optional index mappings before computing start indices.
    if src_mapping:
        src_idx = transform_index_on_mapping(src_mapping, src_symbolic_shape, src_idx)
    if dst_mapping:
        dst_idx = transform_index_on_mapping(dst_mapping, dst_symbolic_shape, dst_idx)

    # One transfer moves a 1-D vector of `elements_per_thread` elements.
    transfer_type = VectorType.get((elements_per_thread,), element_type)

    src_start, src_start_wg, src_start_th = _build_start_indices(emitter, src_idx)
    dst_start, _, _ = _build_start_indices(emitter, dst_idx)

    if False:  # TODO: Buffer stuff needs upstream fixes
        # Disabled path: linearize the source memref and address it through
        # a single per-thread offset instead of the full index list.
        strides = strides_from_symbolic_shape(
            IndexingContext.current(), src_symbolic_shape, allow_mixed_shapes=True
        )
        strides = [gen_sympy_index(add_emitter_subs(emitter), s) for s in strides]

        src_memref, offset_th = _linearize_memref(
            src_memref, src_start_wg, src_start_th, strides
        )
        src_memref = _cast_buffer_and_encode_stride(
            src_memref, strides, element_type, emitter
        )

        src_start = [offset_th]

    amdgpu_d.gather_to_lds(
        src=src_memref,
        src_indices=src_start,
        dst=dst_memref,
        dst_indices=dst_start,
        transfer_type=transfer_type,
    )
0 commit comments