Skip to content

Commit 73fc9a9

Browse files
author
morelos
committed
[ET-VK] additional dtype runtime support along with their aliases
Pull Request resolved: #11365 # Context This diff generally aims provide improvements to the existing framework for defining dtype GLSL shader variants, along with setting up support that would be necessary for future shader implementations that wish to support int64 and double dtypes. In order to allow doubles as input/output dtypes for dequantization and quantization, this diff will create the dtype runtime support on the Vulkan backend in Executorch by establishing the relationship between different tensor types and different GLSL types. # Changes The main changes are included in `gen_vulkan_spv.py` which maps the relationship between different dtypes and their GLSL types. For instance, we add aliases for every common dtype which includes `uint8`, `int8`, `uint16`, `int16`, `uint32`, `int32`, `uint64`, `int64`, and `double`. We maintain support for `int`, `uint`, and `bool` alises such that we can avoid making the change overly complex while supporting the most common recognizable alias (int). Furthermore, this diff also modifies the vulkan api to incorporate new types, namely `uint32_t`, `double`, the int16 and int64 variants. We then make sure that the `ShaderNameUtils` (which is commonly used by most operators for creating their variant names), utilizes the new aliasing. Beyond that we also throw an exception to disallow YAML files to include just "int", and to be more specific, like with "int32". We then modify dozens of files to switch to the new alias of int32. Furthermore, we also include double in certain shaders that are used as intermediaries for image to buffer to nchw converisons. ghstack-source-id: 290376491 @exported-using-ghexport Differential Revision: [D75959063](https://our.internmc.facebook.com/intern/diff/D75959063/)
1 parent 6cb1fac commit 73fc9a9

28 files changed

+190
-68
lines changed

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 103 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -56,52 +56,97 @@
5656
TYPE_MAPPINGS: Dict[str, Any] = {
5757
"IMAGE_T": {
5858
3: {
59+
"double": "image3D",
5960
"float": "image3D",
6061
"half": "image3D",
61-
"int": "iimage3D",
62-
"uint": "uimage3D",
62+
# integer dtypes
6363
"int8": "iimage3D",
6464
"uint8": "uimage3D",
65+
"int16": "iimage3D",
66+
"uint16": "uimage3D",
67+
"int32": "iimage3D",
68+
"uint32": "uimage3D",
69+
"int64": "iimage3D",
70+
"uint64": "uimage3D",
71+
# common dtype aliases
6572
"bool": "uimage3D",
73+
"int": "iimage3D",
74+
"uint": "uimage3D",
6675
},
6776
2: {
77+
"double": "image2D",
6878
"float": "image2D",
6979
"half": "image2D",
70-
"int": "iimage2D",
71-
"uint": "uimage2D",
80+
# integer dtypes
7281
"int8": "iimage2D",
7382
"uint8": "uimage2D",
83+
"int16": "iimage2D",
84+
"uint16": "uimage2D",
85+
"int32": "iimage2D",
86+
"uint32": "uimage2D",
87+
"int64": "iimage2D",
88+
"uint64": "uimage2D",
89+
# common dtype aliases
7490
"bool": "uimage2D",
91+
"int": "iimage2D",
92+
"uint": "uimage2D",
7593
},
7694
},
7795
"SAMPLER_T": {
7896
3: {
97+
"double": "sampler3D",
7998
"float": "sampler3D",
8099
"half": "sampler3D",
81-
"int": "isampler3D",
82-
"uint": "usampler3D",
100+
# integer dtypes
83101
"int8": "isampler3D",
84102
"uint8": "usampler3D",
103+
"int16": "isampler3D",
104+
"uint16": "usampler3D",
105+
"int32": "isampler3D",
106+
"uint32": "usampler3D",
107+
"int64": "isampler3D",
108+
"uint64": "usampler3D",
109+
# common dtype aliases
85110
"bool": "usampler3D",
111+
"int": "isampler3D",
112+
"uint": "usampler3D",
86113
},
87114
2: {
115+
"double": "sampler2D",
88116
"float": "sampler2D",
89117
"half": "sampler2D",
90-
"int": "isampler2D",
91-
"uint": "usampler2D",
118+
# integer dtypes
92119
"int8": "isampler2D",
93120
"uint8": "usampler2D",
121+
"int16": "isampler2D",
122+
"uint16": "usampler2D",
123+
"int32": "isampler2D",
124+
"uint32": "usampler2D",
125+
"int64": "isampler2D",
126+
"uint64": "usampler2D",
127+
# common dtype aliases
94128
"bool": "usampler2D",
129+
"int": "isampler2D",
130+
"uint": "usampler2D",
95131
},
96132
},
97133
"IMAGE_FORMAT": {
134+
"double": "rgba32f",
98135
"float": "rgba32f",
99136
"half": "rgba16f",
100-
"int": "rgba32i",
101-
"uint": "rgba32ui",
137+
# integer dtypes
102138
"int8": "rgba8i",
103139
"uint8": "rgba8ui",
140+
"int16": "rgba16i",
141+
"uint16": "rgba16ui",
142+
"int32": "rgba32i",
143+
"uint32": "rgba32ui",
144+
"int64": "rgba32i",
145+
"uint64": "rgba32ui",
146+
# common dtype aliases
104147
"bool": "rgba8ui",
148+
"int": "rgba32i",
149+
"uint": "rgba32ui",
105150
},
106151
}
107152

@@ -118,33 +163,47 @@ def define_variable(name: str) -> str:
118163
def buffer_scalar_type(dtype: str) -> str:
119164
if dtype == "half":
120165
return "float16_t"
121-
elif dtype[-1] == "8":
122-
return dtype + "_t"
166+
elif dtype == "float":
167+
return "float"
168+
elif dtype == "double":
169+
return "float64_t"
170+
# integer dtype alias conversion
123171
elif dtype == "bool":
124172
return "uint8_t"
173+
# we don't want to append _t for int32 or uint32 as int is already 32bit
174+
elif dtype == "int32" or dtype == "uint32":
175+
return "int" if dtype == "int32" else "uint"
176+
elif dtype[-1].isdigit():
177+
return dtype + "_t"
125178
return dtype
126179

127180

128181
def buffer_gvec_type(dtype: str, n: int) -> str:
129182
if n == 1:
130183
return buffer_scalar_type(dtype)
131184

132-
if dtype == "float":
133-
return f"vec{n}"
134-
if dtype == "uint":
135-
return f"uvec{n}"
136-
elif dtype == "half":
137-
return f"f16vec{n}"
138-
elif dtype == "int":
139-
return f"ivec{n}"
140-
elif dtype == "int8":
141-
return f"i8vec{n}"
142-
elif dtype == "uint8":
143-
return f"u8vec{n}"
144-
elif dtype == "bool":
145-
return f"u8vec{n}"
146-
147-
raise AssertionError(f"Invalid dtype: {dtype}")
185+
dtype_map = {
186+
"half": f"f16vec{n}",
187+
"float": f"vec{n}",
188+
"double": f"vec{n}", # No 64bit image format support in GLSL
189+
"int8": f"i8vec{n}",
190+
"uint8": f"u8vec{n}",
191+
"int16": f"i16vec{n}",
192+
"uint16": f"u16vec{n}",
193+
"int32": f"ivec{n}",
194+
"int": f"ivec{n}",
195+
"uint32": f"uvec{n}",
196+
"uint": f"uvec{n}",
197+
"int64": f"ivec{n}", # No 64bit image format support in GLSL
198+
"uint64": f"uvec{n}", # No 64bit image format support in GLSL
199+
"bool": f"u8vec{n}",
200+
}
201+
202+
vector_type = dtype_map.get(dtype)
203+
if vector_type is None:
204+
raise AssertionError(f"Invalid dtype: {dtype}")
205+
206+
return vector_type
148207

149208

150209
def texel_type(dtype: str) -> str:
@@ -365,15 +424,22 @@ def define_required_extensions(dtypes: Union[str, List[str]]):
365424
if dtype == "half":
366425
nbit = "16bit"
367426
glsl_type = "float16"
368-
elif dtype == "int16" or dtype == "uint16":
369-
nbit = "16bit"
370-
glsl_type = "int16"
371-
elif dtype == "int8" or dtype == "uint8" or dtype == "bool":
427+
elif dtype == "double":
428+
# We only need to allow float64_t type usage
429+
glsl_type = "float64"
430+
elif dtype in ["int8", "uint8", "bool"]:
372431
nbit = "8bit"
373432
glsl_type = "int8"
433+
elif dtype in ["int16", "uint16"]:
434+
nbit = "16bit"
435+
glsl_type = "int16"
436+
elif dtype in ["int64", "uint64"]:
437+
# We only need to allow int64_t and uint64_t type usage
438+
glsl_type = "int64"
374439

375-
if nbit is not None and glsl_type is not None:
440+
if nbit is not None:
376441
out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n"
442+
if glsl_type is not None:
377443
out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n"
378444

379445
return out_str
@@ -629,6 +695,10 @@ def generateVariantCombinations(
629695

630696
elif "VALUE" in value:
631697
suffix = value.get("SUFFIX", value["VALUE"])
698+
if value["VALUE"] in ["int", "uint"]:
699+
raise ValueError(
700+
f"Use int32 or uint32 instead of {value['VALUE']}"
701+
)
632702
param_values.append((param_name, suffix, value["VALUE"]))
633703

634704
else:

backends/vulkan/runtime/graph/ops/glsl/arange.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
arange:
88
parameter_names_with_default_values:
99
NDIM: 3
10-
DTYPE: int
10+
DTYPE: int32
1111
STORAGE: texture3d
1212
PACKING: C_packed
1313
generate_variant_forall:
1414
DTYPE:
1515
- VALUE: half
1616
- VALUE: float
17-
- VALUE: int
17+
- VALUE: int32
1818
shader_variants:
1919
- NAME: arange

backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ avg_pool2d:
1313
DTYPE:
1414
- VALUE: half
1515
- VALUE: float
16-
- VALUE: int
16+
- VALUE: int32
1717
shader_variants:
1818
- NAME: avg_pool2d

backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ binary_op:
1717
DTYPE:
1818
- VALUE: half
1919
- VALUE: float
20-
- VALUE: int
20+
- VALUE: int32
2121
shader_variants:
2222
- NAME: binary_add
2323
- NAME: binary_sub

backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ buffer_to_buffer:
1212
DTYPE:
1313
- VALUE: half
1414
- VALUE: float
15-
- VALUE: int
15+
- VALUE: double
1616
- VALUE: int8
1717
- VALUE: uint8
18+
- VALUE: int32
1819
shader_variants:
1920
- NAME: buffer_to_buffer

backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ buffer_to_nchw:
1313
DTYPE:
1414
- VALUE: half
1515
- VALUE: float
16-
- VALUE: int
16+
- VALUE: double
1717
- VALUE: int8
1818
- VALUE: uint8
19+
- VALUE: int32
1920
shader_variants:
2021
- NAME: buffer_to_nchw
2122
- NAME: buffer_to_nchw_no_pc

backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ copy_channel_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: copy_channel_offset

backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ copy_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
- VALUE: int8
1212
- VALUE: uint8
1313
STORAGE:

backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ copy_packed_dim_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: copy_packed_dim_offset

backends/vulkan/runtime/graph/ops/glsl/embedding.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ embedding:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: embedding

0 commit comments

Comments
 (0)