Skip to content

Commit f2c2380

Browse files
author
morelos
committed
Update on "[ET-VK][Ops] quantization op shaders and impl"
Create the quantize_per_tensor and quantize_per_token shaders and their implementation, which are linked with the testing framework. NOTE: Currently the only supported input types are **half** (fp16) and **float** (fp32). The only supported output types are **byte** (uint8), **char** (int8), **short** (int16), and **int** (int32). Differential Revision: [D75959064](https://our.internmc.facebook.com/intern/diff/D75959064/) [ghstack-poisoned]
2 parents 0c9c7a6 + 7e29a59 commit f2c2380

28 files changed

+269
-125
lines changed

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 82 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -59,64 +59,94 @@
5959
"double": "image3D",
6060
"float": "image3D",
6161
"half": "image3D",
62-
"int": "iimage3D",
63-
"uint": "uimage3D",
62+
# integer dtypes
6463
"int8": "iimage3D",
6564
"uint8": "uimage3D",
66-
"bool": "uimage3D",
67-
"short": "iimage3D",
65+
"int16": "iimage3D",
6866
"uint16": "uimage3D",
67+
"int32": "iimage3D",
68+
"uint32": "uimage3D",
69+
"int64": "iimage3D",
70+
"uint64": "uimage3D",
71+
# common dtype aliases
72+
"bool": "uimage3D",
73+
"int": "iimage3D",
74+
"uint": "uimage3D",
6975
},
7076
2: {
7177
"double": "image2D",
7278
"float": "image2D",
7379
"half": "image2D",
74-
"int": "iimage2D",
75-
"uint": "uimage2D",
80+
# integer dtypes
7681
"int8": "iimage2D",
7782
"uint8": "uimage2D",
78-
"bool": "uimage2D",
79-
"short": "iimage2D",
83+
"int16": "iimage2D",
8084
"uint16": "uimage2D",
85+
"int32": "iimage2D",
86+
"uint32": "uimage2D",
87+
"int64": "iimage2D",
88+
"uint64": "uimage2D",
89+
# common dtype aliases
90+
"bool": "uimage2D",
91+
"int": "iimage2D",
92+
"uint": "uimage2D",
8193
},
8294
},
8395
"SAMPLER_T": {
8496
3: {
8597
"double": "sampler3D",
8698
"float": "sampler3D",
8799
"half": "sampler3D",
88-
"int": "isampler3D",
89-
"uint": "usampler3D",
100+
# integer dtypes
90101
"int8": "isampler3D",
91102
"uint8": "usampler3D",
92-
"bool": "usampler3D",
93-
"short": "isampler3D",
103+
"int16": "isampler3D",
94104
"uint16": "usampler3D",
105+
"int32": "isampler3D",
106+
"uint32": "usampler3D",
107+
"int64": "isampler3D",
108+
"uint64": "usampler3D",
109+
# common dtype aliases
110+
"bool": "usampler3D",
111+
"int": "isampler3D",
112+
"uint": "usampler3D",
95113
},
96114
2: {
97115
"double": "sampler2D",
98116
"float": "sampler2D",
99117
"half": "sampler2D",
100-
"int": "isampler2D",
101-
"uint": "usampler2D",
118+
# integer dtypes
102119
"int8": "isampler2D",
103120
"uint8": "usampler2D",
104-
"bool": "usampler2D",
105-
"short": "isampler2D",
121+
"int16": "isampler2D",
106122
"uint16": "usampler2D",
123+
"int32": "isampler2D",
124+
"uint32": "usampler2D",
125+
"int64": "isampler2D",
126+
"uint64": "usampler2D",
127+
# common dtype aliases
128+
"bool": "usampler2D",
129+
"int": "isampler2D",
130+
"uint": "usampler2D",
107131
},
108132
},
109133
"IMAGE_FORMAT": {
110-
"double": "rgba64f",
134+
"double": "rgba32f",
111135
"float": "rgba32f",
112136
"half": "rgba16f",
113-
"int": "rgba32i",
114-
"uint": "rgba32ui",
137+
# integer dtypes
115138
"int8": "rgba8i",
116139
"uint8": "rgba8ui",
117-
"bool": "rgba8ui",
118-
"short": "rgba16i",
140+
"int16": "rgba16i",
119141
"uint16": "rgba16ui",
142+
"int32": "rgba32i",
143+
"uint32": "rgba32ui",
144+
"int64": "rgba32i",
145+
"uint64": "rgba32ui",
146+
# common dtype aliases
147+
"bool": "rgba8ui",
148+
"int": "rgba32i",
149+
"uint": "rgba32ui",
120150
},
121151
}
122152

@@ -137,10 +167,12 @@ def buffer_scalar_type(dtype: str) -> str:
137167
return "float"
138168
elif dtype == "double":
139169
return "float64_t"
140-
elif dtype == "short":
141-
return "int16_t"
170+
# integer dtype alias conversion
142171
elif dtype == "bool":
143172
return "uint8_t"
173+
# we don't want to append _t for int32 or uint32 as int is already 32bit
174+
elif dtype == "int32" or dtype == "uint32":
175+
return "int" if dtype == "int32" else "uint"
144176
elif dtype[-1].isdigit():
145177
return dtype + "_t"
146178
return dtype
@@ -150,26 +182,29 @@ def buffer_gvec_type(dtype: str, n: int) -> str:
150182
if n == 1:
151183
return buffer_scalar_type(dtype)
152184

153-
if dtype == "float":
154-
return f"vec{n}"
155-
if dtype == "uint":
156-
return f"uvec{n}"
157-
elif dtype == "half":
185+
if dtype == "half":
158186
return f"f16vec{n}"
187+
elif dtype == "float":
188+
return f"vec{n}"
159189
elif dtype == "double":
160-
return f"dvec{n}"
161-
elif dtype == "int":
162-
return f"ivec{n}"
163-
elif dtype == "short":
164-
return f"i16vec{n}"
165-
elif dtype == "uint16":
166-
return f"u16vec{n}"
190+
return f"f64vec{n}"
191+
# integer dtype
167192
elif dtype == "int8":
168193
return f"i8vec{n}"
169194
elif dtype == "uint8":
170195
return f"u8vec{n}"
171-
elif dtype == "bool":
172-
return f"u8vec{n}"
196+
elif dtype == "int16":
197+
return f"i16vec{n}"
198+
elif dtype == "uint16":
199+
return f"u16vec{n}"
200+
elif dtype == "int32" or dtype == "int":
201+
return f"ivec{n}"
202+
elif dtype == "uint32" or dtype == "uint":
203+
return f"uvec{n}"
204+
elif dtype == "int64":
205+
return f"i64vec{n}"
206+
elif dtype == "uint64":
207+
return f"u64vec{n}"
173208

174209
raise AssertionError(f"Invalid dtype: {dtype}")
175210

@@ -387,22 +422,19 @@ def define_required_extensions(dtypes: Union[str, List[str]]):
387422
dtype_list = dtypes if isinstance(dtypes, list) else [dtypes]
388423

389424
for dtype in dtype_list:
390-
nbit = None
391425
glsl_type = None
392426
if dtype == "half":
393-
nbit = "16bit"
394427
glsl_type = "float16"
395-
elif dtype == "short" or dtype == "int16" or dtype == "uint16":
396-
nbit = "16bit"
397-
glsl_type = "int16"
398-
elif dtype == "bool" or dtype == "int8" or dtype == "uint8":
399-
nbit = "8bit"
428+
elif dtype == "double":
429+
glsl_type = "float64"
430+
elif dtype in ["int8", "uint8"]:
400431
glsl_type = "int8"
401-
elif dtype == "double" or dtype == "float64":
402-
out_str += "#extension GL_ARB_gpu_shader_fp64 : require\n"
432+
elif dtype in ["int16", "uint16"]:
433+
glsl_type = "int16"
434+
elif dtype in ["int64", "uint64"]:
435+
glsl_type = "int64"
403436

404-
if nbit is not None and glsl_type is not None:
405-
out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n"
437+
if glsl_type is not None:
406438
out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n"
407439

408440
return out_str
@@ -658,6 +690,8 @@ def generateVariantCombinations(
658690

659691
elif "VALUE" in value:
660692
suffix = value.get("SUFFIX", value["VALUE"])
693+
if value["VALUE"] in ["int", "uint"]:
694+
raise ValueError(f"Use int32 or uint32 instead of {value['VALUE']}")
661695
param_values.append((param_name, suffix, value["VALUE"]))
662696

663697
else:

backends/vulkan/runtime/graph/ops/glsl/arange.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@ arange:
1414
DTYPE:
1515
- VALUE: half
1616
- VALUE: float
17-
- VALUE: int
17+
- VALUE: int32
1818
shader_variants:
1919
- NAME: arange

backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ avg_pool2d:
1313
DTYPE:
1414
- VALUE: half
1515
- VALUE: float
16-
- VALUE: int
16+
- VALUE: int32
1717
shader_variants:
1818
- NAME: avg_pool2d

backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ binary_op:
1717
DTYPE:
1818
- VALUE: half
1919
- VALUE: float
20-
- VALUE: int
20+
- VALUE: int32
2121
shader_variants:
2222
- NAME: binary_add
2323
- NAME: binary_sub

backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ buffer_to_buffer:
1212
DTYPE:
1313
- VALUE: half
1414
- VALUE: float
15-
- VALUE: int
15+
- VALUE: double
1616
- VALUE: int8
1717
- VALUE: uint8
18-
- VALUE: short
19-
- VALUE: uint16
18+
- VALUE: int32
19+
- VALUE: int64
2020
shader_variants:
2121
- NAME: buffer_to_buffer

backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ buffer_to_nchw:
1212
DTYPE:
1313
- VALUE: half
1414
- VALUE: float
15-
- VALUE: int
15+
- VALUE: double
1616
- VALUE: int8
1717
- VALUE: uint8
18-
- VALUE: short
19-
- VALUE: uint16
18+
- VALUE: int32
19+
- VALUE: int64
2020
shader_variants:
2121
- NAME: buffer_to_nchw

backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ copy_channel_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: copy_channel_offset

backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ copy_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
- VALUE: int8
1212
- VALUE: uint8
1313
STORAGE:

backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ copy_packed_dim_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: copy_packed_dim_offset

backends/vulkan/runtime/graph/ops/glsl/embedding.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ embedding:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: embedding

0 commit comments

Comments
 (0)