
Commit 5bc6d0c

Update on "[ET-VK] Replace Uniform buffers with push constants for view op"
This diff replaces uniform buffers with push constants for the view op in the Vulkan backend of ExecuTorch. The changes update the GLSL code to use push constants instead of uniform buffers and update the C++ code to pass the sizes to the shader as push constants.

Differential Revision: [D66733658](https://our.internmc.facebook.com/intern/diff/D66733658/)

[ghstack-poisoned]
2 parents 501484e + a1adfca commit 5bc6d0c

File tree: 30 files changed, +985 −191 lines

backends/arm/test/conftest.py

Lines changed: 11 additions & 22 deletions

@@ -11,7 +11,6 @@
 import shutil
 import subprocess
 import sys
-from enum import auto, Enum
 from typing import Any

 import pytest
@@ -22,30 +21,24 @@
 """


-class arm_test_options(Enum):
-    quantize_io = auto()
-    corstone_fvp = auto()
-    fast_fvp = auto()
-
-
-_test_options: dict[arm_test_options, Any] = {}
-
 # ==== Pytest hooks ====


 def pytest_configure(config):
+    pytest._test_options = {}
+
     if config.option.arm_quantize_io:
         _load_libquantized_ops_aot_lib()
-        _test_options[arm_test_options.quantize_io] = True
+        pytest._test_options["quantize_io"] = True
     if config.option.arm_run_corstoneFVP:
         corstone300_exists = shutil.which("FVP_Corstone_SSE-300_Ethos-U55")
         corstone320_exists = shutil.which("FVP_Corstone_SSE-320")
         if not (corstone300_exists and corstone320_exists):
             raise RuntimeError(
                 "Tests are run with --arm_run_corstoneFVP but corstone FVP is not installed."
             )
-        _test_options[arm_test_options.corstone_fvp] = True
-        _test_options[arm_test_options.fast_fvp] = config.option.fast_fvp
+        pytest._test_options["corstone_fvp"] = True
+        pytest._test_options["fast_fvp"] = config.option.fast_fvp
     logging.basicConfig(level=logging.INFO, stream=sys.stdout)
@@ -131,9 +124,7 @@ def expectedFailureOnFVP(test_item):
 # ==== End of Custom Pytest decorators =====


-def is_option_enabled(
-    option: str | arm_test_options, fail_if_not_enabled: bool = False
-) -> bool:
+def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool:
     """
     Returns whether an option is successfully enabled, i.e. if the flag was
     given to pytest and the necessary requirements are available.
@@ -144,10 +135,8 @@ def is_option_enabled(
     The optional parameter 'fail_if_not_enabled' makes the function raise
     a RuntimeError instead of returning False.
     """
-    if isinstance(option, str):
-        option = arm_test_options[option.lower()]

-    if option in _test_options and _test_options[option]:
+    if option in pytest._test_options and pytest._test_options[option]:
         return True
     else:
         if fail_if_not_enabled:
@@ -156,15 +145,15 @@ def is_option_enabled(
         return False


-def get_option(option: arm_test_options) -> Any | None:
+def get_option(option: str) -> Any | None:
     """
     Returns the value of an pytest option if it is set, otherwise None.

     Args:
-        option (arm_test_options): The option to check for.
+        option (str): The option to check for.
     """
-    if option in _test_options:
-        return _test_options[option]
+    if option in pytest._test_options:
+        return pytest._test_options[option]
     return None
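With this change the Arm test options are stored in a plain string-keyed dict attached to the pytest module instead of the module-level dict keyed by the arm_test_options enum. A minimal usage sketch, assuming the executorch package is importable and pytest_configure has already populated pytest._test_options (the option names are the ones set in the diff above):

    # Minimal sketch: querying the string-keyed options after pytest_configure has run.
    from executorch.backends.arm.test.conftest import get_option, is_option_enabled

    if is_option_enabled("fast_fvp"):  # replaces is_option_enabled(arm_test_options.fast_fvp)
        extra_args = "--fast"          # mirrors the runner_utils.py usage further down

    value = get_option("corstone_fvp")  # stored value, or None if the flag was not given
    print(value)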

backends/arm/test/ops/test_depthwise_conv.py

Lines changed: 7 additions & 6 deletions

@@ -156,14 +156,11 @@
     ("two_dw_conv2d", two_dw_conv2d),
 ]

-testsuite_conv2d_u85 = [
+testsuite_conv2d_u85_xfails = [
     ("2x2_1x6x4x4_gp6_st1", dw_conv2d_2x2_1x6x4x4_gp6_st1),
     ("3x3_1x3x256x256_gp3_st1", dw_conv2d_3x3_1x3x256x256_gp3_st1),
     ("3x3_1x4x256x256_gp4_st1", dw_conv2d_3x3_1x4x256x256_gp4_st1),
     ("3x3_1x4x256x256_gp4_nobias", dw_conv2d_3x3_1x4x256x256_gp4_nobias),
-]
-
-testsuite_conv2d_u85_xfails = [
     ("3x3_2x8x198x198_gp8_st3", dw_conv2d_3x3_2x8x198x198_gp8_st3),
     ("two_dw_conv2d", two_dw_conv2d),
 ]
@@ -287,7 +284,7 @@ def test_dw_conv1d_u55_BI(
             model.get_inputs(),
         )

-    @parameterized.expand(testsuite_conv1d + testsuite_conv2d_u85)
+    @parameterized.expand(testsuite_conv1d[2:])
     def test_dw_conv_u85_BI(
         self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False
     ):
@@ -299,8 +296,12 @@ def test_dw_conv_u85_BI(
             model.get_inputs(),
         )

+    testsuite_conv2d_u85_xfails.remove(
+        ("3x3_1x3x256x256_gp3_st1", dw_conv2d_3x3_1x3x256x256_gp3_st1)
+    ) # Works
+
     # All test cases except 3x3_1x3x256x256_gp3_st1 have numerical issues on FVP. MLETORCH-520
-    @parameterized.expand(testsuite_conv2d_u85_xfails)
+    @parameterized.expand(testsuite_conv2d_u85_xfails + testsuite_conv1d[:2])
     @conftest.expectedFailureOnFVP
     def test_dw_conv_u85_BI_xfails(
         self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False
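The regrouping above is plain list slicing and concatenation of (test name, model) pairs before they are handed to parameterized.expand. A standalone sketch with hypothetical entries (strings stand in for the real model instances) shows how the cases are split between the passing test and the expected-failure test:

    # Hypothetical suites; strings stand in for the real torch.nn.Module instances.
    testsuite_conv1d = [("c1d_a", "mod_a"), ("c1d_b", "mod_b"), ("c1d_c", "mod_c")]
    testsuite_conv2d_u85_xfails = [("c2d_x", "mod_x"), ("c2d_y", "mod_y")]

    passing_cases = testsuite_conv1d[2:]  # everything from the third conv1d case on
    xfail_cases = testsuite_conv2d_u85_xfails + testsuite_conv1d[:2]  # xfails plus the first two conv1d cases

    print(passing_cases)  # [('c1d_c', 'mod_c')]
    print(xfail_cases)    # [('c2d_x', 'mod_x'), ('c2d_y', 'mod_y'), ('c1d_a', 'mod_a'), ('c1d_b', 'mod_b')]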

backends/arm/test/ops/test_div.py

Lines changed: 2 additions & 28 deletions

@@ -183,21 +183,8 @@ def test_div_tosa_BI(
         test_data = (input_, other_)
         self._test_div_tosa_BI_pipeline(self.Div(), test_data)

-    @parameterized.expand(test_data_suite[:2])
-    def test_div_u55_BI(
-        self,
-        test_name: str,
-        input_: Union[torch.Tensor, torch.types.Number],
-        other_: Union[torch.Tensor, torch.types.Number],
-        rounding_mode: Optional[str] = None,
-    ):
-        test_data = (input_, other_)
-        self._test_div_ethos_BI_pipeline(
-            self.Div(), common.get_u55_compile_spec(), test_data
-        )
-
     # Numerical issues on FVP likely due to mul op, MLETORCH-521
-    @parameterized.expand(test_data_suite[2:])
+    @parameterized.expand(test_data_suite)
     @conftest.expectedFailureOnFVP
     def test_div_u55_BI_xfails(
         self,
@@ -211,21 +198,8 @@ def test_div_u55_BI_xfails(
             self.Div(), common.get_u55_compile_spec(), test_data
         )

-    @parameterized.expand(test_data_suite[:2])
-    def test_div_u85_BI(
-        self,
-        test_name: str,
-        input_: Union[torch.Tensor, torch.types.Number],
-        other_: Union[torch.Tensor, torch.types.Number],
-        rounding_mode: Optional[str] = None,
-    ):
-        test_data = (input_, other_)
-        self._test_div_ethos_BI_pipeline(
-            self.Div(), common.get_u85_compile_spec(), test_data
-        )
-
     # Numerical issues on FVP likely due to mul op, MLETORCH-521
-    @parameterized.expand(test_data_suite[2:])
+    @parameterized.expand(test_data_suite)
     @conftest.expectedFailureOnFVP
     def test_div_u85_BI_xfails(
         self,
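conftest.expectedFailureOnFVP itself is not part of this diff (only its hunk header appears above), so the following is a hedged sketch of the general pattern rather than the project's implementation: a decorator of this kind can defer to unittest.expectedFailure only when the Corstone FVP option is active, so the marked cases still run but their known numerical failures are reported as expected.

    # Hedged sketch only; the real expectedFailureOnFVP in conftest.py may differ.
    import unittest

    from executorch.backends.arm.test.conftest import is_option_enabled


    def expected_failure_on_fvp_sketch(test_item):
        # Mark as an expected failure only when tests actually target the Corstone FVP.
        if is_option_enabled("corstone_fvp"):
            return unittest.expectedFailure(test_item)
        return test_item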

backends/arm/test/ops/test_mul.py

Lines changed: 6 additions & 1 deletion

@@ -152,7 +152,9 @@ def test_mul_tosa_BI(
         test_data = (input_, other_)
         self._test_mul_tosa_BI_pipeline(self.Mul(), test_data)

+    # Numerical issues on FVP, MLETORCH-521
     @parameterized.expand(test_data_sute)
+    @conftest.expectedFailureOnFVP
     def test_mul_u55_BI(
         self,
         test_name: str,
@@ -164,7 +166,10 @@ def test_mul_u55_BI(
             common.get_u55_compile_spec(), self.Mul(), test_data
         )

-    @parameterized.expand(test_data_sute)
+    # Numerical issues on FVP, MLETORCH-521
+    # test_data_sute[0] works on U85
+    @parameterized.expand(test_data_sute[1:])
+    @conftest.expectedFailureOnFVP
     def test_mul_u85_BI(
         self,
         test_name: str,

backends/arm/test/runner_utils.py

Lines changed: 2 additions & 2 deletions

@@ -17,7 +17,7 @@
 import numpy as np
 import torch

-from executorch.backends.arm.test.conftest import arm_test_options, is_option_enabled
+from executorch.backends.arm.test.conftest import is_option_enabled

 from torch.export import ExportedProgram
 from torch.fx.node import Node
@@ -251,7 +251,7 @@ def run_corstone(
         cmd_line += f" -i {input_path}"

     ethos_u_extra_args = ""
-    if is_option_enabled(arm_test_options.fast_fvp):
+    if is_option_enabled("fast_fvp"):
         ethos_u_extra_args = ethos_u_extra_args + "--fast"

     command_args = {

backends/cadence/aot/functions_fusion_g3.yaml

Lines changed: 17 additions & 4 deletions

@@ -20,7 +20,7 @@
 - op: _softmax.out
   kernels:
     - arg_meta: null
-      kernel_name: cadence::impl::G3::softmax_out
+      kernel_name: cadence::impl::G3::_softmax_out

 - op: add.out
   kernels:
@@ -71,7 +71,7 @@
   kernels:
     - arg_meta: null
       kernel_name: cadence::impl::G3::mul_out
-
+
 - op: mul.Scalar_out
   kernels:
     - arg_meta: null
@@ -111,8 +111,21 @@
   kernels:
     - arg_meta: null
       kernel_name: torch::executor::where_out
-
+
 - op: native_layer_norm.out
   kernels:
     - arg_meta: null
-      kernel_name: cadence::impl::G3::native_layer_norm_out
+      kernel_name: cadence::impl::G3::native_layer_norm_out
+
+# custom ops
+- func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::G3::native::quantize_per_tensor_out
+
+- func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::G3::native::dequantize_per_tensor_out

New file

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+load("targets.bzl", "define_common_targets")
+
+oncall("odai_jarvis")
+
+define_common_targets()

backends/cadence/fusion_g3/operators/op_dequantize.cpp

Lines changed: 17 additions & 11 deletions

@@ -52,7 +52,7 @@ void check_dequantize_per_tensor_args(
   ET_CHECK_MSG(
       input.scalar_type() == ScalarType::Byte ||
           input.scalar_type() == ScalarType::Char ||
-          input.scalar_type() == ScalarType::Bits16 ||
+          input.scalar_type() == ScalarType::UInt16 ||
           input.scalar_type() == ScalarType::Short ||
           input.scalar_type() == (ScalarType)Ushort ||
           input.scalar_type() == (ScalarType)Bits4 ||
@@ -83,7 +83,7 @@ void check_dequantize_per_tensor_args(
 } // namespace

 /* Local function which calls the kernels based on the input datatype */
-void Dequantize_impl(
+void dequantize_impl(
     Tensor& out,
     const Tensor& input,
     float* scale_data,
@@ -211,7 +211,7 @@ void Dequantize_impl(
         break;
     switch (input.scalar_type()) {
       ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_TENSOR);
-      ASYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16);
+      ASYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, UInt16);
      default:
        ET_CHECK_MSG(
            false,
@@ -302,7 +302,7 @@ void Dequantize_impl(
        break;
    switch (input.scalar_type()) {
      ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_CHANNEL);
-      ASYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16);
+      ASYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, UInt16);
     default:
       ET_CHECK_MSG(
           false,
@@ -368,7 +368,7 @@ void Dequantize_impl(
        break;
    switch (input.scalar_type()) {
      ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_TENSOR);
-      SYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16);
+      SYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, UInt16);
     default:
       ET_CHECK_MSG(
           false,
@@ -459,7 +459,7 @@ void Dequantize_impl(
        break;
    switch (input.scalar_type()) {
      ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_CHANNEL);
-      SYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16);
+      SYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, UInt16);
     default:
       ET_CHECK_MSG(
           false,
@@ -502,7 +502,7 @@ Tensor& dequantize_per_tensor_out(
   float scale_data = (float)scale;
   int zero_point_data = (int)zero_point;

-  Dequantize_impl(out, input, &scale_data, &zero_point_data, NULL, out_dtype);
+  dequantize_impl(out, input, &scale_data, &zero_point_data, NULL, out_dtype);

   return out;
 }
@@ -620,7 +620,7 @@ Tensor& dequantize_per_channel_out(
   for (int i = 0; i < scale.numel(); i++) {
     scale_data[i] = (float)scale_dt[i];
   }
-  Dequantize_impl(out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);
+  dequantize_impl(out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);

   return out;
 }
@@ -661,13 +661,19 @@ Tensor& dequantize_per_tensor_out(
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
-    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   // TODO(larryliu): Add a context arg to the real op function and remove this
   // wrapper
   (void)context;
   return dequantize_per_tensor_out(
-      input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
+      input,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      dtype,
+      out.scalar_type(),
+      out);
 }

 Tensor& dequantize_per_tensor_tensor_args_out(
@@ -764,4 +770,4 @@ Tensor& dequantize_per_token_out(
 } // namespace native
 } // namespace G3
 } // namespace impl
-} // namespace cadence
+} // namespace cadence
