
Commit 1b1e484

Update on "[ET-VK] Minor build graph change to improve model load time and memory."
A minor change in GraphBuilder to avoid creating a temporary vector and to reserve memory up front while building each operator. Differential Revision: [D73864959](https://our.internmc.facebook.com/intern/diff/D73864959/) [ghstack-poisoned]
Merge commit 1b1e484 (2 parents: 2eff3c8 + 88087cc)
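The GraphBuilder change the message describes is C++ code in the Vulkan (ET-VK) delegate; the sketch below is only a rough, hypothetical Python analogy of the pattern it names (size the destination container once and fill it in place, instead of filling a temporary container and copying it). It is not the commit's code.

```python
# Hypothetical illustration only: not the GraphBuilder code from this commit.
def build_args_with_temp(value_refs, resolve):
    tmp = []                      # temporary container built first...
    for ref in value_refs:
        tmp.append(resolve(ref))
    return list(tmp)              # ...then copied into the final container

def build_args_preallocated(value_refs, resolve):
    args = [None] * len(value_refs)   # "reserve" the final size once
    for i, ref in enumerate(value_refs):
        args[i] = resolve(ref)        # fill in place, no temporary or copy
    return args

# Example usage with a trivial resolver.
print(build_args_preallocated([1, 2, 3], resolve=lambda r: r * 10))  # [10, 20, 30]
```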

12 files changed: +1146 additions, -52 deletions


backends/cadence/aot/reorder_ops.py

Lines changed: 25 additions & 23 deletions
@@ -30,33 +30,35 @@
 
 # A list of ops that can be trivially quantized
 trivially_quantizable_ops_overloadpkt = {
-    torch.ops.aten.slice_copy,
-    torch.ops.aten.slice,
-    torch.ops.aten.view_copy,
-    torch.ops.aten.view,
-    torch.ops.aten.clone,
-    torch.ops.aten.transpose_copy,
-    torch.ops.aten.transpose,
-    torch.ops.aten.permute_copy,
-    torch.ops.aten.permute,
-    torch.ops.aten.squeeze_copy,
-    torch.ops.aten.squeeze,
-    torch.ops.aten.unsqueeze_copy,
-    torch.ops.aten.unsqueeze,
-    torch.ops.aten.chunk,
-    torch.ops.aten.contiguous,
-    torch.ops.aten.select_copy,
-    exir_ops.edge.aten.slice_copy,
-    exir_ops.edge.aten.view_copy,
+    exir_ops.edge.aten.chunk,
     exir_ops.edge.aten.clone,
-    exir_ops.edge.aten.transpose_copy,
+    exir_ops.edge.aten.contiguous,
+    exir_ops.edge.aten.expand_copy,
     exir_ops.edge.aten.permute_copy,
+    exir_ops.edge.aten.select_copy,
+    exir_ops.edge.aten.slice_copy,
     exir_ops.edge.aten.squeeze_copy,
-    exir_ops.edge.aten.unsqueeze_copy,
+    exir_ops.edge.aten.transpose_copy,
     exir_ops.edge.aten.unfold_copy,
-    exir_ops.edge.aten.chunk,
-    exir_ops.edge.aten.contiguous,
-    exir_ops.edge.aten.select_copy,
+    exir_ops.edge.aten.unsqueeze_copy,
+    exir_ops.edge.aten.view_copy,
+    torch.ops.aten.chunk,
+    torch.ops.aten.clone,
+    torch.ops.aten.contiguous,
+    torch.ops.aten.expand_copy,
+    torch.ops.aten.permute,
+    torch.ops.aten.permute_copy,
+    torch.ops.aten.select_copy,
+    torch.ops.aten.slice,
+    torch.ops.aten.slice_copy,
+    torch.ops.aten.squeeze,
+    torch.ops.aten.squeeze_copy,
+    torch.ops.aten.transpose,
+    torch.ops.aten.transpose_copy,
+    torch.ops.aten.unsqueeze,
+    torch.ops.aten.unsqueeze_copy,
+    torch.ops.aten.view,
+    torch.ops.aten.view_copy,
 }
 
 # slice-equivalent ops
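These entries are op overload packets, so a pass that wants to move quantize/dequantize nodes past such ops can gate on a simple membership test. A minimal, hypothetical sketch of that check (the helper name and the node-access pattern are illustrative, not code from reorder_ops.py):

```python
import torch

# Abbreviated copy of the set above; the real module lists many more entries.
trivially_quantizable_ops_overloadpkt = {
    torch.ops.aten.permute_copy,
    torch.ops.aten.view_copy,
}

def is_trivially_quantizable(node) -> bool:
    """Return True if a torch.fx node calls one of the ops in the set.

    torch.fx stores the resolved OpOverload in node.target; its
    .overloadpacket attribute is what the set above contains.
    """
    return (
        node.op == "call_function"
        and getattr(node.target, "overloadpacket", None)
        in trivially_quantizable_ops_overloadpkt
    )
```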

backends/cadence/hifi/operators/operators.h

Lines changed: 31 additions & 0 deletions
@@ -12,6 +12,11 @@
   _(uint8_t, Byte) \
   _(int8_t, Char)
 
+using ::executorch::aten::optional;
+using ::executorch::aten::ScalarType;
+using ::executorch::aten::Tensor;
+using ::executorch::runtime::KernelRuntimeContext;
+
 namespace cadence {
 namespace impl {
 namespace HiFi {

@@ -36,6 +41,32 @@ ::executorch::aten::Tensor& div_out_mode(
     ::executorch::aten::optional<::executorch::aten::string_view> mode,
     ::executorch::aten::Tensor& out);
 
+void quantized_linear_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& in,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t in_zero_point,
+    const Tensor& weight_zero_point,
+    const Tensor& out_multiplier,
+    const Tensor& out_shift,
+    int64_t out_zero_point,
+    __ET_UNUSED const optional<Tensor>& offset,
+    Tensor& out);
+
+void quantized_linear_per_tensor_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& in,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t in_zero_point,
+    int64_t weight_zero_point,
+    int64_t out_multiplier,
+    int64_t out_shift,
+    int64_t out_zero_point,
+    __ET_UNUSED const optional<Tensor>& offset,
+    Tensor& out);
+
 } // namespace native
 } // namespace HiFi
 } // namespace impl
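For orientation, the new declarations follow the usual affine-quantization convention: subtract zero points, accumulate in int32, add bias, then requantize with a fixed-point multiplier and shift before adding the output zero point. The NumPy sketch below restates that convention for the per-tensor variant; it is not the HiFi kernel, and the exact rounding, shift direction, and output-range details are assumptions.

```python
import numpy as np

def quantized_linear_per_tensor_ref(
    x_q, w_q, bias_q, in_zero_point, weight_zero_point,
    out_multiplier, out_shift, out_zero_point,
):
    """Plain-NumPy model of a per-tensor quantized linear layer.

    x_q: (N, K) int8 activations, w_q: (M, K) int8 weights,
    bias_q: (M,) int32 bias already in the accumulator scale.
    Rounding/saturation choices here are assumptions, not the kernel's spec.
    """
    # Integer accumulation with zero points removed.
    acc = (x_q.astype(np.int32) - in_zero_point) @ (w_q.astype(np.int32) - weight_zero_point).T
    acc = acc + bias_q.astype(np.int64)
    # Requantize: Q31 fixed-point multiplier, then a power-of-two shift
    # (the sign convention of out_shift is an assumption here).
    scaled = acc * (out_multiplier / float(1 << 31)) * (2.0 ** out_shift)
    out = np.rint(scaled) + out_zero_point
    return np.clip(out, -128, 127).astype(np.int8)
```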

backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl

Lines changed: 23 additions & 26 deletions
@@ -60,9 +60,9 @@ const lowp int out_packed_dim = unhash_packed_dim(out_layout);
 // First iteration of reduce will have 32 threads sum up 64 elements.
 // Second iteration will have 32 threads sum up 16 elements from previous iteration and so on.
 // Thus thread utilization starts at 100%.
-#define SHARED_MEMORY_FACTOR 2
+#define SHARED_MEMORY_FACTOR 1
 
-#define offset_pos_index(index) ((index) + ((index) >> 2))
+#define offset_pos_index(index) ((index) + ((index) >> 3))
 
 shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)];

@@ -154,14 +154,13 @@ void reduce_non_packed_dim() {
       if (all(lessThan(in_pos, out_limits))) {
         in_val = load_texel(t_in, in_pos);
       }
-      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
+      mean += in_val;
     }
-
-    reduce_input(width_stride, shared_idx_offset);
-    mean += shared_input[offset_pos_index(shared_idx_offset)];
   }
 
-  mean /= width;
+  shared_input[offset_pos_index(shared_idx)] = mean;
+  reduce_input(width_stride, shared_idx_offset);
+  mean = shared_input[offset_pos_index(shared_idx_offset)] / width;
 
   memoryBarrierShared();
   barrier();

@@ -178,14 +177,13 @@
       }
 
       const VEC4_T delta = in_val - mean;
-      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
+      var += delta * delta;
     }
-
-    reduce_input(width_stride, shared_idx_offset);
-    var += shared_input[offset_pos_index(shared_idx_offset)];
   }
 
-  var /= width;
+  shared_input[offset_pos_index(shared_idx)] = var;
+  reduce_input(width_stride, shared_idx_offset);
+  var = shared_input[offset_pos_index(shared_idx_offset)] / width;
 
   VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
   VEC4_T offset = -rstd * mean;

@@ -226,6 +224,7 @@ void reduce_packed_dim() {
 
   const int in_pos_x_limit = out_limits[in_axis_map.x];
 
+  VEC4_T accum = VEC4_T(0);
   // Loop over the width in stride increments
   for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
     // Read input in shared memory

@@ -244,20 +243,20 @@
         in_val.z = mix(in_val.z, T(0), remain_inv > 1);
         in_val.w = mix(in_val.w, T(0), remain_inv > 0);
       }
-
-      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
+      accum += in_val;
     }
-
-    reduce_input(width_stride, shared_idx_offset);
-    const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
-    mean += val.x + val.y + val.z + val.w;
   }
 
-  mean /= width;
+  shared_input[offset_pos_index(shared_idx)] = accum;
+  reduce_input(width_stride, shared_idx_offset);
+  VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
+  mean = (val.x + val.y + val.z + val.w) / width;
 
   memoryBarrierShared();
   barrier();
 
+  VEC4_T delta2 = VEC4_T(0);
+
   // Loop over the width in stride increments
   for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
     // Read input in shared memory

@@ -278,16 +277,14 @@
       }
 
       const VEC4_T delta = in_val - mean;
-      const VEC4_T delta2 = delta * delta;
-      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
+      delta2 += delta * delta;
     }
-
-    reduce_input(width_stride, shared_idx_offset);
-    const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
-    var += val.x + val.y + val.z + val.w;
   }
 
-  var /= width;
+  shared_input[offset_pos_index(shared_idx)] = delta2;
+  reduce_input(width_stride, shared_idx_offset);
+  val = shared_input[offset_pos_index(shared_idx_offset)];
+  var = (val.x + val.y + val.z + val.w) / width;
 
   T rstd = pow(var + epsilon, T(-0.5));
   T offset = -rstd * mean;
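The net effect of this change is that each invocation now accumulates a private partial sum over its strided slice of the width, writes that single value to shared memory, and runs one tree reduction, instead of staging every loaded texel in shared memory and reducing once per stride. A small NumPy sketch of the same two-pass mean/variance pattern, with "threads" modeled as strided slices; this is illustrative only, not the shader:

```python
import numpy as np

def layer_norm_stats(x: np.ndarray, num_threads: int = 32, eps: float = 1e-5):
    """Two-pass mean/variance in the style of the shader: each 'thread'
    accumulates a private partial sum over its strided slice, the partials
    are reduced once, and the same pattern repeats for the variance."""
    width = x.shape[-1]
    # Pass 1: per-thread partial sums over strided slices, then one reduction.
    partial_sums = np.array(
        [x[..., t::num_threads].sum(axis=-1) for t in range(num_threads)]
    )
    mean = partial_sums.sum(axis=0) / width
    # Pass 2: per-thread partial sums of squared deltas, then one reduction.
    partial_sq = np.array(
        [((x[..., t::num_threads] - mean[..., None]) ** 2).sum(axis=-1)
         for t in range(num_threads)]
    )
    var = partial_sq.sum(axis=0) / width
    rstd = (var + eps) ** -0.5
    return mean, rstd

# Example: matches the direct per-row mean.
x = np.random.randn(4, 128).astype(np.float32)
mean, rstd = layer_norm_stats(x)
assert np.allclose(mean, x.mean(axis=-1), atol=1e-5)
```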

backends/xnnpack/quantizer/xnnpack_quantizer.py

Lines changed: 26 additions & 3 deletions
@@ -292,6 +292,9 @@ def __init__(self) -> None:
         ] = {}
         self.module_type_config: dict[Callable, Optional[QuantizationConfig]] = {}
         self.module_name_config: dict[str, Optional[QuantizationConfig]] = {}
+        # If specified, only quantize nodes that return true for the filter
+        # function.
+        self.filter_fn: Optional[Callable[[Node], bool]] = None
 
     @classmethod
     def get_supported_quantization_configs(cls) -> list[QuantizationConfig]:

@@ -355,6 +358,14 @@ def set_module_name(
         self.module_name_config[module_name] = quantization_config
         return self
 
+    def set_filter_function(self, filter_fn: Callable[[Node], bool]):
+        """
+        Set the filter function. We only quantize nodes that return True for
+        the filter function.
+        """
+        self.filter_fn = filter_fn
+        return self
+
     def transform_for_annotation(
         self, model: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:

@@ -378,17 +389,29 @@ def _annotate_all_patterns(
         if quantization_config is None:
             return model
 
+        # Create a combined filter function, which returns True only when
+        # both filter_fn and self.filter_fn return True.
+        def combined_filter_fn(n: Node) -> bool:
+            combined_filter = [self.filter_fn, filter_fn]
+            return all(f(n) for f in combined_filter if f is not None)
+
         for pattern in self.SUPPORTED_PATTERNS:
             if operator_target and operator_target not in pattern.op_overloads:
                 # if operator_target is specified, skip patterns that aren't
                 # associated with that target
                 continue
             if quantization_config.input_activation.is_dynamic and pattern.is_dynamic:
-                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
+                OP_TO_ANNOTATOR[pattern.name](
+                    model, quantization_config, combined_filter_fn
+                )
             elif quantization_config.is_qat and pattern.is_qat:
-                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
+                OP_TO_ANNOTATOR[pattern.name](
+                    model, quantization_config, combined_filter_fn
+                )
             elif not quantization_config.input_activation.is_dynamic:
-                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
+                OP_TO_ANNOTATOR[pattern.name](
+                    model, quantization_config, combined_filter_fn
+                )
 
         return model
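Taken together with `set_global`, the new hook means a node is annotated only when both the per-pattern filter and the user-supplied filter accept it. A minimal usage sketch mirroring the new test below (the import path is assumed from the file location; the config and filter calls are the ones in this diff):

```python
# Import path assumed from backends/xnnpack/quantizer/xnnpack_quantizer.py.
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

# Quantize everything except the node named "linear_1" (the second linear).
quantizer.set_filter_function(lambda n: n.name != "linear_1")
```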

backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py

Lines changed: 30 additions & 0 deletions
@@ -297,6 +297,36 @@ def test_obs_sharing_ops(self):
         ]
         self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list)
 
+    def test_set_filter_fn(self):
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+        quantizer.set_global(quantization_config)
+        m_eager = TestHelperModules.TwoLinearModule().eval()
+
+        # Set the filter function so that the second linear is not quantized
+        def filter_fn(n):
+            return n.name != "linear_1"
+
+        quantizer.set_filter_function(filter_fn)
+
+        # Test with 2d inputs
+        example_inputs_2d = (torch.randn(9, 8),)
+        node_occurrence = {
+            # input and output of the first linear op will be (de)quantized
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+            # quantize_per_channel for weights are const propagated
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+            # weight for the first linear will be dequantized
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+        }
+        self._test_quantizer(
+            m_eager,
+            example_inputs_2d,
+            quantizer,
+            node_occurrence,
+        )
+
     def test_set_module_name(self):
         class Sub(torch.nn.Module):
             def __init__(self) -> None:

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java

Lines changed: 1 addition & 0 deletions
@@ -14,4 +14,5 @@ public enum ModelType {
   LLAMA_3_2,
   LLAVA_1_5,
   LLAMA_GUARD_3,
+  QWEN_3,
 }

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ public static int getModelCategory(ModelType modelType, BackendType backendType)
       case LLAMA_3:
       case LLAMA_3_1:
       case LLAMA_3_2:
+      case QWEN_3:
       default:
         return TEXT_MODEL;
     }

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java

Lines changed: 12 additions & 0 deletions
@@ -25,6 +25,8 @@ public static String getSystemPromptTemplate(ModelType modelType) {
             + "<|eot_id|>";
       case LLAVA_1_5:
         return "USER: ";
+      case QWEN_3:
+        return "<|im_start|>system\n" + "You are a helpful assistant.\n" + "<|im_end|>\n";
       default:
         return SYSTEM_PLACEHOLDER;
     }

@@ -42,6 +44,14 @@ public static String getUserPromptTemplate(ModelType modelType) {
             + "<|start_header_id|>assistant<|end_header_id|>";
 
       case LLAVA_1_5:
+      case QWEN_3:
+        return "<|im_start|>user\n"
+            + USER_PLACEHOLDER
+            + "<|im_end|>\n"
+            + "<|im_start|>assistant\n"
+            + "<think>\n"
+            + "\n"
+            + "</think>\n\n\n";
       default:
         return USER_PLACEHOLDER;
     }

@@ -69,6 +79,8 @@ public static String getStopToken(ModelType modelType) {
         return "<|eot_id|>";
       case LLAVA_1_5:
         return "</s>";
+      case QWEN_3:
+        return "<|endoftext|>";
       default:
         return "";
     }
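Concretely, for a single user turn the two Qwen 3 templates above compose into the prompt shown in the small sketch below. The string contents are copied from the diff; the system-then-user ordering is an assumption about how the app stitches templates together, and the template pre-fills an empty `<think>` block for the assistant turn.

```python
# Illustrative composition of the Qwen 3 templates added above.
system_prompt = "<|im_start|>system\n" + "You are a helpful assistant.\n" + "<|im_end|>\n"
user_prompt = (
    "<|im_start|>user\n"
    + "Tell me about ExecuTorch."   # stands in for USER_PLACEHOLDER
    + "<|im_end|>\n"
    + "<|im_start|>assistant\n"
    + "<think>\n"
    + "\n"
    + "</think>\n\n\n"
)
print(system_prompt + user_prompt)
```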

examples/models/llama/README.md

Lines changed: 2 additions & 0 deletions
@@ -308,6 +308,8 @@ Note for Mac users: There's a known linking issue with Xcode 15.1. Refer to the
 
 To build for CoreML backend and validate on Mac, replace `-DEXECUTORCH_BUILD_XNNPACK=ON` with `-DEXECUTORCH_BUILD_COREML=ON`
 
+If you see an error about "RE2 failed to compile pattern with lookahead:...SUPPORT_REGEX_LOOKAHEAD=ON", add "-DSUPPORT_REGEX_LOOKAHEAD=ON" when building the runner.
+
 ## Step 4: Run benchmark on Android phone
 
 **1. Build llama runner binary for Android**
