Skip to content

Commit e9787ec

Browse files
author
morelos
committed
Update base for Update on "[ET] enabling half dtype output for dequantization and making logic consistent"
Enabling half dtype output and making dequantization logic consistent between per_tensor and per_token as it is currently prone to integer overflows on one over the other Differential Revision: [D76289181](https://our.internmc.facebook.com/intern/diff/D76289181/) [ghstack-poisoned]
1 parent 8acaa6e commit e9787ec

40 files changed

+134
-1940
lines changed

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 37 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -56,97 +56,52 @@
5656
TYPE_MAPPINGS: Dict[str, Any] = {
5757
"IMAGE_T": {
5858
3: {
59-
"double": "image3D",
6059
"float": "image3D",
6160
"half": "image3D",
62-
# integer dtypes
61+
"int": "iimage3D",
62+
"uint": "uimage3D",
6363
"int8": "iimage3D",
6464
"uint8": "uimage3D",
65-
"int16": "iimage3D",
66-
"uint16": "uimage3D",
67-
"int32": "iimage3D",
68-
"uint32": "uimage3D",
69-
"int64": "iimage3D",
70-
"uint64": "uimage3D",
71-
# common dtype aliases
7265
"bool": "uimage3D",
73-
"int": "iimage3D",
74-
"uint": "uimage3D",
7566
},
7667
2: {
77-
"double": "image2D",
7868
"float": "image2D",
7969
"half": "image2D",
80-
# integer dtypes
70+
"int": "iimage2D",
71+
"uint": "uimage2D",
8172
"int8": "iimage2D",
8273
"uint8": "uimage2D",
83-
"int16": "iimage2D",
84-
"uint16": "uimage2D",
85-
"int32": "iimage2D",
86-
"uint32": "uimage2D",
87-
"int64": "iimage2D",
88-
"uint64": "uimage2D",
89-
# common dtype aliases
9074
"bool": "uimage2D",
91-
"int": "iimage2D",
92-
"uint": "uimage2D",
9375
},
9476
},
9577
"SAMPLER_T": {
9678
3: {
97-
"double": "sampler3D",
9879
"float": "sampler3D",
9980
"half": "sampler3D",
100-
# integer dtypes
81+
"int": "isampler3D",
82+
"uint": "usampler3D",
10183
"int8": "isampler3D",
10284
"uint8": "usampler3D",
103-
"int16": "isampler3D",
104-
"uint16": "usampler3D",
105-
"int32": "isampler3D",
106-
"uint32": "usampler3D",
107-
"int64": "isampler3D",
108-
"uint64": "usampler3D",
109-
# common dtype aliases
11085
"bool": "usampler3D",
111-
"int": "isampler3D",
112-
"uint": "usampler3D",
11386
},
11487
2: {
115-
"double": "sampler2D",
11688
"float": "sampler2D",
11789
"half": "sampler2D",
118-
# integer dtypes
90+
"int": "isampler2D",
91+
"uint": "usampler2D",
11992
"int8": "isampler2D",
12093
"uint8": "usampler2D",
121-
"int16": "isampler2D",
122-
"uint16": "usampler2D",
123-
"int32": "isampler2D",
124-
"uint32": "usampler2D",
125-
"int64": "isampler2D",
126-
"uint64": "usampler2D",
127-
# common dtype aliases
12894
"bool": "usampler2D",
129-
"int": "isampler2D",
130-
"uint": "usampler2D",
13195
},
13296
},
13397
"IMAGE_FORMAT": {
134-
"double": "rgba32f",
13598
"float": "rgba32f",
13699
"half": "rgba16f",
137-
# integer dtypes
100+
"int": "rgba32i",
101+
"uint": "rgba32ui",
138102
"int8": "rgba8i",
139103
"uint8": "rgba8ui",
140-
"int16": "rgba16i",
141-
"uint16": "rgba16ui",
142-
"int32": "rgba32i",
143-
"uint32": "rgba32ui",
144-
"int64": "rgba32i",
145-
"uint64": "rgba32ui",
146-
# common dtype aliases
147104
"bool": "rgba8ui",
148-
"int": "rgba32i",
149-
"uint": "rgba32ui",
150105
},
151106
}
152107

@@ -163,47 +118,33 @@ def define_variable(name: str) -> str:
163118
def buffer_scalar_type(dtype: str) -> str:
164119
if dtype == "half":
165120
return "float16_t"
166-
elif dtype == "float":
167-
return "float"
168-
elif dtype == "double":
169-
return "float64_t"
170-
# integer dtype alias conversion
121+
elif dtype[-1] == "8":
122+
return dtype + "_t"
171123
elif dtype == "bool":
172124
return "uint8_t"
173-
# we don't want to append _t for int32 or uint32 as int is already 32bit
174-
elif dtype == "int32" or dtype == "uint32":
175-
return "int" if dtype == "int32" else "uint"
176-
elif dtype[-1].isdigit():
177-
return dtype + "_t"
178125
return dtype
179126

180127

181128
def buffer_gvec_type(dtype: str, n: int) -> str:
182129
if n == 1:
183130
return buffer_scalar_type(dtype)
184131

185-
dtype_map = {
186-
"half": f"f16vec{n}",
187-
"float": f"vec{n}",
188-
"double": f"vec{n}", # No 64bit support in GLSL
189-
"int8": f"i8vec{n}",
190-
"uint8": f"u8vec{n}",
191-
"int16": f"i16vec{n}",
192-
"uint16": f"u16vec{n}",
193-
"int32": f"ivec{n}",
194-
"int": f"ivec{n}",
195-
"uint32": f"uvec{n}",
196-
"uint": f"uvec{n}",
197-
"int64": f"ivec{n}", # No 64bit support in GLSL
198-
"uint64": f"uvec{n}", # No 64bit support in GLSL
199-
"bool": f"u8vec{n}",
200-
}
201-
202-
vector_type = dtype_map.get(dtype)
203-
if vector_type is None:
204-
raise AssertionError(f"Invalid dtype: {dtype}")
205-
206-
return vector_type
132+
if dtype == "float":
133+
return f"vec{n}"
134+
if dtype == "uint":
135+
return f"uvec{n}"
136+
elif dtype == "half":
137+
return f"f16vec{n}"
138+
elif dtype == "int":
139+
return f"ivec{n}"
140+
elif dtype == "int8":
141+
return f"i8vec{n}"
142+
elif dtype == "uint8":
143+
return f"u8vec{n}"
144+
elif dtype == "bool":
145+
return f"u8vec{n}"
146+
147+
raise AssertionError(f"Invalid dtype: {dtype}")
207148

208149

209150
def texel_type(dtype: str) -> str:
@@ -419,19 +360,20 @@ def define_required_extensions(dtypes: Union[str, List[str]]):
419360
dtype_list = dtypes if isinstance(dtypes, list) else [dtypes]
420361

421362
for dtype in dtype_list:
363+
nbit = None
422364
glsl_type = None
423365
if dtype == "half":
366+
nbit = "16bit"
424367
glsl_type = "float16"
425-
elif dtype == "double":
426-
glsl_type = "float64"
427-
elif dtype in ["int8", "uint8", "bool"]:
428-
glsl_type = "int8"
429-
elif dtype in ["int16", "uint16"]:
368+
elif dtype == "int16" or dtype == "uint16":
369+
nbit = "16bit"
430370
glsl_type = "int16"
431-
elif dtype in ["int64", "uint64"]:
432-
glsl_type = "int64"
371+
elif dtype == "int8" or dtype == "uint8" or dtype == "bool":
372+
nbit = "8bit"
373+
glsl_type = "int8"
433374

434-
if glsl_type is not None:
375+
if nbit is not None and glsl_type is not None:
376+
out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n"
435377
out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n"
436378

437379
return out_str
@@ -687,10 +629,6 @@ def generateVariantCombinations(
687629

688630
elif "VALUE" in value:
689631
suffix = value.get("SUFFIX", value["VALUE"])
690-
if value["VALUE"] in ["int", "uint"]:
691-
raise ValueError(
692-
f"Use int32 or uint32 instead of {value['VALUE']}"
693-
)
694632
param_values.append((param_name, suffix, value["VALUE"]))
695633

696634
else:

backends/vulkan/runtime/graph/ops/glsl/arange.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
arange:
88
parameter_names_with_default_values:
99
NDIM: 3
10-
DTYPE: int32
10+
DTYPE: int
1111
STORAGE: texture3d
1212
PACKING: C_packed
1313
generate_variant_forall:
1414
DTYPE:
1515
- VALUE: half
1616
- VALUE: float
17-
- VALUE: int32
17+
- VALUE: int
1818
shader_variants:
1919
- NAME: arange

backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ avg_pool2d:
1313
DTYPE:
1414
- VALUE: half
1515
- VALUE: float
16-
- VALUE: int32
16+
- VALUE: int
1717
shader_variants:
1818
- NAME: avg_pool2d

backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ binary_op:
1717
DTYPE:
1818
- VALUE: half
1919
- VALUE: float
20-
- VALUE: int32
20+
- VALUE: int
2121
shader_variants:
2222
- NAME: binary_add
2323
- NAME: binary_sub

backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@ buffer_to_buffer:
1212
DTYPE:
1313
- VALUE: half
1414
- VALUE: float
15-
- VALUE: double
15+
- VALUE: int
1616
- VALUE: int8
1717
- VALUE: uint8
18-
- VALUE: int32
1918
shader_variants:
2019
- NAME: buffer_to_buffer

backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@ buffer_to_nchw:
1313
DTYPE:
1414
- VALUE: half
1515
- VALUE: float
16-
- VALUE: double
16+
- VALUE: int
1717
- VALUE: int8
1818
- VALUE: uint8
19-
- VALUE: int32
2019
shader_variants:
2120
- NAME: buffer_to_nchw
2221
- NAME: buffer_to_nchw_no_pc

backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ copy_channel_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int32
10+
- VALUE: int
1111
shader_variants:
1212
- NAME: copy_channel_offset

backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ copy_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int32
10+
- VALUE: int
1111
- VALUE: int8
1212
- VALUE: uint8
1313
STORAGE:

backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ copy_packed_dim_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int32
10+
- VALUE: int
1111
shader_variants:
1212
- NAME: copy_packed_dim_offset

backends/vulkan/runtime/graph/ops/glsl/embedding.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ embedding:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int32
10+
- VALUE: int
1111
shader_variants:
1212
- NAME: embedding

0 commit comments

Comments
 (0)