
Commit 9d32edf

Update on "[ET-VK] Using vector for storing ref_mapping_ in GraphBuilder to improve model load time and memory."
This diff changes the GraphBuilder class to store the ref-id-to-value mapping as a vector instead of an unordered map; since the maximum id is known, the vector can be pre-sized to hold the entire mapping.

Differential Revision: [D73969916](https://our.internmc.facebook.com/intern/diff/D73969916/)

[ghstack-poisoned]
2 parents 7069b15 + af220d6 commit 9d32edf
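
The idea behind the change, as a minimal Python sketch (illustration only — the actual change is in the C++ GraphBuilder, where `ref_mapping_` becomes a pre-sized `std::vector` instead of an `std::unordered_map`; the class and method names below are hypothetical):

```python
# Hypothetical sketch: a dense, pre-sized mapping keyed by small integer ids.
from typing import Generic, Optional, TypeVar

V = TypeVar("V")


class RefMapping(Generic[V]):
    """Maps ref ids to values using a list sized to the known maximum id."""

    def __init__(self, max_ref_id: int) -> None:
        # One slot per possible id: O(1) lookups with no hashing or rehashing.
        self._values: list[Optional[V]] = [None] * (max_ref_id + 1)

    def insert(self, ref_id: int, value: V) -> None:
        self._values[ref_id] = value

    def get(self, ref_id: int) -> Optional[V]:
        return self._values[ref_id]


# Usage: size the mapping once from the known maximum id, then index directly.
mapping: RefMapping[str] = RefMapping(max_ref_id=7)
mapping.insert(3, "value_ref_3")
assert mapping.get(3) == "value_ref_3"
```

Because the ids are dense and bounded, lookup becomes a direct index into pre-allocated storage, which avoids hashing and rehashing work during model load.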

File tree

12 files changed, +1146 −52 lines


backends/cadence/aot/reorder_ops.py

Lines changed: 25 additions & 23 deletions
@@ -30,33 +30,35 @@

 # A list of ops that can be trivially quantized
 trivially_quantizable_ops_overloadpkt = {
-    torch.ops.aten.slice_copy,
-    torch.ops.aten.slice,
-    torch.ops.aten.view_copy,
-    torch.ops.aten.view,
-    torch.ops.aten.clone,
-    torch.ops.aten.transpose_copy,
-    torch.ops.aten.transpose,
-    torch.ops.aten.permute_copy,
-    torch.ops.aten.permute,
-    torch.ops.aten.squeeze_copy,
-    torch.ops.aten.squeeze,
-    torch.ops.aten.unsqueeze_copy,
-    torch.ops.aten.unsqueeze,
-    torch.ops.aten.chunk,
-    torch.ops.aten.contiguous,
-    torch.ops.aten.select_copy,
-    exir_ops.edge.aten.slice_copy,
-    exir_ops.edge.aten.view_copy,
+    exir_ops.edge.aten.chunk,
     exir_ops.edge.aten.clone,
-    exir_ops.edge.aten.transpose_copy,
+    exir_ops.edge.aten.contiguous,
+    exir_ops.edge.aten.expand_copy,
     exir_ops.edge.aten.permute_copy,
+    exir_ops.edge.aten.select_copy,
+    exir_ops.edge.aten.slice_copy,
     exir_ops.edge.aten.squeeze_copy,
-    exir_ops.edge.aten.unsqueeze_copy,
+    exir_ops.edge.aten.transpose_copy,
     exir_ops.edge.aten.unfold_copy,
-    exir_ops.edge.aten.chunk,
-    exir_ops.edge.aten.contiguous,
-    exir_ops.edge.aten.select_copy,
+    exir_ops.edge.aten.unsqueeze_copy,
+    exir_ops.edge.aten.view_copy,
+    torch.ops.aten.chunk,
+    torch.ops.aten.clone,
+    torch.ops.aten.contiguous,
+    torch.ops.aten.expand_copy,
+    torch.ops.aten.permute,
+    torch.ops.aten.permute_copy,
+    torch.ops.aten.select_copy,
+    torch.ops.aten.slice,
+    torch.ops.aten.slice_copy,
+    torch.ops.aten.squeeze,
+    torch.ops.aten.squeeze_copy,
+    torch.ops.aten.transpose,
+    torch.ops.aten.transpose_copy,
+    torch.ops.aten.unsqueeze,
+    torch.ops.aten.unsqueeze_copy,
+    torch.ops.aten.view,
+    torch.ops.aten.view_copy,
 }

 # slice-equivalent ops

backends/cadence/hifi/operators/operators.h

Lines changed: 31 additions & 0 deletions
@@ -12,6 +12,11 @@
   _(uint8_t, Byte) \
   _(int8_t, Char)

+using ::executorch::aten::optional;
+using ::executorch::aten::ScalarType;
+using ::executorch::aten::Tensor;
+using ::executorch::runtime::KernelRuntimeContext;
+
 namespace cadence {
 namespace impl {
 namespace HiFi {
@@ -36,6 +41,32 @@ ::executorch::aten::Tensor& div_out_mode(
     ::executorch::aten::optional<::executorch::aten::string_view> mode,
     ::executorch::aten::Tensor& out);

+void quantized_linear_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& in,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t in_zero_point,
+    const Tensor& weight_zero_point,
+    const Tensor& out_multiplier,
+    const Tensor& out_shift,
+    int64_t out_zero_point,
+    __ET_UNUSED const optional<Tensor>& offset,
+    Tensor& out);
+
+void quantized_linear_per_tensor_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& in,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t in_zero_point,
+    int64_t weight_zero_point,
+    int64_t out_multiplier,
+    int64_t out_shift,
+    int64_t out_zero_point,
+    __ET_UNUSED const optional<Tensor>& offset,
+    Tensor& out);
+
 } // namespace native
 } // namespace HiFi
 } // namespace impl

backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl

Lines changed: 23 additions & 26 deletions
@@ -60,9 +60,9 @@ const lowp int out_packed_dim = unhash_packed_dim(out_layout);
 // First iteration of reduce will have 32 threads sum up 64 elements.
 // Second iteration will have 32 threads sum up 16 elements from previous iteration and so on.
 // Thus thread utilization starts at 100%.
-#define SHARED_MEMORY_FACTOR 2
+#define SHARED_MEMORY_FACTOR 1

-#define offset_pos_index(index) ((index) + ((index) >> 2))
+#define offset_pos_index(index) ((index) + ((index) >> 3))

 shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)];

@@ -154,14 +154,13 @@ void reduce_non_packed_dim() {
       if (all(lessThan(in_pos, out_limits))) {
         in_val = load_texel(t_in, in_pos);
       }
-      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
+      mean += in_val;
     }
-
-    reduce_input(width_stride, shared_idx_offset);
-    mean += shared_input[offset_pos_index(shared_idx_offset)];
   }

-  mean /= width;
+  shared_input[offset_pos_index(shared_idx)] = mean;
+  reduce_input(width_stride, shared_idx_offset);
+  mean = shared_input[offset_pos_index(shared_idx_offset)] / width;

   memoryBarrierShared();
   barrier();
@@ -178,14 +177,13 @@ void reduce_non_packed_dim() {
       }

       const VEC4_T delta = in_val - mean;
-      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
+      var += delta * delta;
     }
-
-    reduce_input(width_stride, shared_idx_offset);
-    var += shared_input[offset_pos_index(shared_idx_offset)];
   }

-  var /= width;
+  shared_input[offset_pos_index(shared_idx)] = var;
+  reduce_input(width_stride, shared_idx_offset);
+  var = shared_input[offset_pos_index(shared_idx_offset)] / width;

   VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
   VEC4_T offset = -rstd * mean;
@@ -226,6 +224,7 @@ void reduce_packed_dim() {

   const int in_pos_x_limit = out_limits[in_axis_map.x];

+  VEC4_T accum = VEC4_T(0);
   // Loop over the width in stride increments
   for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
     // Read input in shared memory
@@ -244,20 +243,20 @@ void reduce_packed_dim() {
         in_val.z = mix(in_val.z, T(0), remain_inv > 1);
         in_val.w = mix(in_val.w, T(0), remain_inv > 0);
       }
-
-      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
+      accum += in_val;
     }
-
-    reduce_input(width_stride, shared_idx_offset);
-    const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
-    mean += val.x + val.y + val.z + val.w;
   }

-  mean /= width;
+  shared_input[offset_pos_index(shared_idx)] = accum;
+  reduce_input(width_stride, shared_idx_offset);
+  VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
+  mean = (val.x + val.y + val.z + val.w) / width;

   memoryBarrierShared();
   barrier();

+  VEC4_T delta2 = VEC4_T(0);
+
   // Loop over the width in stride increments
   for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
     // Read input in shared memory
@@ -278,16 +277,14 @@ void reduce_packed_dim() {
       }

       const VEC4_T delta = in_val - mean;
-      const VEC4_T delta2 = delta * delta;
-      shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
+      delta2 += delta * delta;
     }
-
-    reduce_input(width_stride, shared_idx_offset);
-    const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
-    var += val.x + val.y + val.z + val.w;
   }

-  var /= width;
+  shared_input[offset_pos_index(shared_idx)] = delta2;
+  reduce_input(width_stride, shared_idx_offset);
+  val = shared_input[offset_pos_index(shared_idx_offset)];
+  var = (val.x + val.y + val.z + val.w) / width;

   T rstd = pow(var + epsilon, T(-0.5));
   T offset = -rstd * mean;
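
The net effect of the shader change above is that each invocation now accumulates its partial sum in a register and only the per-thread partials go through the shared-memory tree reduction, so shared memory is written once per thread instead of once per element. A small Python sketch of that two-phase pattern (illustrative only, not GLSL; the workgroup size and input are assumed):

```python
# Illustrative sketch of the two-phase reduction: each "thread" accumulates its
# strided slice privately, writes one partial sum to shared memory, then a tree
# reduction combines the partials.
import numpy as np

WORKGROUP_SIZE = 32  # assumed workgroup size, mirroring the 32-thread comment above


def workgroup_sum(values: np.ndarray) -> float:
    # Phase 1: per-thread private accumulation over a strided range
    # (corresponds to `mean += in_val` / `accum += in_val` above).
    shared = np.zeros(WORKGROUP_SIZE)
    for tid in range(WORKGROUP_SIZE):
        shared[tid] = values[tid::WORKGROUP_SIZE].sum()

    # Phase 2: shared-memory tree reduction, halving the active threads each
    # step (corresponds to `reduce_input(...)` above).
    stride = WORKGROUP_SIZE // 2
    while stride > 0:
        for tid in range(stride):
            shared[tid] += shared[tid + stride]
        stride //= 2
    return float(shared[0])


values = np.arange(256, dtype=np.float64)
assert workgroup_sum(values) == values.sum()
```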

backends/xnnpack/quantizer/xnnpack_quantizer.py

Lines changed: 26 additions & 3 deletions
@@ -292,6 +292,9 @@ def __init__(self) -> None:
         ] = {}
         self.module_type_config: dict[Callable, Optional[QuantizationConfig]] = {}
         self.module_name_config: dict[str, Optional[QuantizationConfig]] = {}
+        # If specified, only quantize nodes that return true for the filter
+        # function.
+        self.filter_fn: Optional[Callable[[Node], bool]] = None

     @classmethod
     def get_supported_quantization_configs(cls) -> list[QuantizationConfig]:
@@ -355,6 +358,14 @@ def set_module_name(
         self.module_name_config[module_name] = quantization_config
         return self

+    def set_filter_function(self, filter_fn: Callable[[Node], bool]):
+        """
+        Set the filter function. We only quantize nodes that return True for
+        the filter function.
+        """
+        self.filter_fn = filter_fn
+        return self
+
     def transform_for_annotation(
         self, model: torch.fx.GraphModule
     ) -> torch.fx.GraphModule:
@@ -378,17 +389,29 @@ def _annotate_all_patterns(
         if quantization_config is None:
             return model

+        # Create a combined filter function, which returns True only when
+        # both filter_fn and self.filter_fn return True.
+        def combined_filter_fn(n: Node) -> bool:
+            combined_filter = [self.filter_fn, filter_fn]
+            return all(f(n) for f in combined_filter if f is not None)
+
         for pattern in self.SUPPORTED_PATTERNS:
             if operator_target and operator_target not in pattern.op_overloads:
                 # if operator_target is specified, skip patterns that aren't
                 # associated with that target
                 continue
             if quantization_config.input_activation.is_dynamic and pattern.is_dynamic:
-                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
+                OP_TO_ANNOTATOR[pattern.name](
+                    model, quantization_config, combined_filter_fn
+                )
             elif quantization_config.is_qat and pattern.is_qat:
-                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
+                OP_TO_ANNOTATOR[pattern.name](
+                    model, quantization_config, combined_filter_fn
+                )
             elif not quantization_config.input_activation.is_dynamic:
-                OP_TO_ANNOTATOR[pattern.name](model, quantization_config, filter_fn)
+                OP_TO_ANNOTATOR[pattern.name](
+                    model, quantization_config, combined_filter_fn
+                )

         return model
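
A short usage sketch of the new hook, mirroring the test added in this commit (import paths assumed from this repository's layout; the filtered node name "linear_1" is illustrative):

```python
# Only nodes for which the filter predicate returns True are annotated for
# quantization; here the second linear ("linear_1") stays in floating point.
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
quantizer.set_filter_function(lambda n: n.name != "linear_1")
```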

backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py

Lines changed: 30 additions & 0 deletions
@@ -297,6 +297,36 @@ def test_obs_sharing_ops(self):
         ]
         self._test_quantizer(m, example_inputs, quantizer, node_occurrence, node_list)

+    def test_set_filter_fn(self):
+        quantizer = XNNPACKQuantizer()
+        quantization_config = get_symmetric_quantization_config(is_per_channel=True)
+        quantizer.set_global(quantization_config)
+        m_eager = TestHelperModules.TwoLinearModule().eval()
+
+        # Set the filter function so that the second linear is not quantized
+        def filter_fn(n):
+            return n.name != "linear_1"
+
+        quantizer.set_filter_function(filter_fn)
+
+        # Test with 2d inputs
+        example_inputs_2d = (torch.randn(9, 8),)
+        node_occurrence = {
+            # input and output of the first linear op will be (de)quantized
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
+            # quantize_per_channel for weights are const propagated
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+            # weight for the first linear will be dequantized
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+        }
+        self._test_quantizer(
+            m_eager,
+            example_inputs_2d,
+            quantizer,
+            node_occurrence,
+        )
+
     def test_set_module_name(self):
         class Sub(torch.nn.Module):
             def __init__(self) -> None:

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelType.java

Lines changed: 1 addition & 0 deletions
@@ -14,4 +14,5 @@ public enum ModelType {
   LLAMA_3_2,
   LLAVA_1_5,
   LLAMA_GUARD_3,
+  QWEN_3,
 }

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ public static int getModelCategory(ModelType modelType, BackendType backendType)
       case LLAMA_3:
       case LLAMA_3_1:
       case LLAMA_3_2:
+      case QWEN_3:
       default:
         return TEXT_MODEL;
     }

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java

Lines changed: 12 additions & 0 deletions
@@ -25,6 +25,8 @@ public static String getSystemPromptTemplate(ModelType modelType) {
             + "<|eot_id|>";
       case LLAVA_1_5:
         return "USER: ";
+      case QWEN_3:
+        return "<|im_start|>system\n" + "You are a helpful assistant.\n" + "<|im_end|>\n";
       default:
         return SYSTEM_PLACEHOLDER;
     }
@@ -42,6 +44,14 @@ public static String getUserPromptTemplate(ModelType modelType) {
             + "<|start_header_id|>assistant<|end_header_id|>";

       case LLAVA_1_5:
+      case QWEN_3:
+        return "<|im_start|>user\n"
+            + USER_PLACEHOLDER
+            + "<|im_end|>\n"
+            + "<|im_start|>assistant\n"
+            + "<think>\n"
+            + "\n"
+            + "</think>\n\n\n";
       default:
         return USER_PLACEHOLDER;
     }
@@ -69,6 +79,8 @@ public static String getStopToken(ModelType modelType) {
         return "<|eot_id|>";
       case LLAVA_1_5:
         return "</s>";
+      case QWEN_3:
+        return "<|endoftext|>";
       default:
         return "";
     }
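
For reference, the Qwen 3 templates added above assemble into a prompt roughly like the following (a Python sketch; the app builds these strings in Java, and the user message stands in for `USER_PLACEHOLDER`):

```python
# Rough rendering of the Qwen 3 prompt built from the templates above
# ("What is ExecuTorch?" stands in for USER_PLACEHOLDER).
system_prompt = "<|im_start|>system\n" + "You are a helpful assistant.\n" + "<|im_end|>\n"
user_prompt = (
    "<|im_start|>user\n"
    + "What is ExecuTorch?"
    + "<|im_end|>\n"
    + "<|im_start|>assistant\n"
    + "<think>\n"
    + "\n"
    + "</think>\n\n\n"
)
print(system_prompt + user_prompt)  # generation stops at the "<|endoftext|>" stop token
```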

examples/models/llama/README.md

Lines changed: 2 additions & 0 deletions
@@ -308,6 +308,8 @@ Note for Mac users: There's a known linking issue with Xcode 15.1. Refer to the

 To build for CoreML backend and validate on Mac, replace `-DEXECUTORCH_BUILD_XNNPACK=ON` with `-DEXECUTORCH_BUILD_COREML=ON`

+If you get an error about "RE2 failed to compile pattern with lookahead:...SUPPORT_REGEX_LOOKAHEAD=ON", add "-DSUPPORT_REGEX_LOOKAHEAD=ON" when building the runner.
+
 ## Step 4: Run benchmark on Android phone

 **1. Build llama runner binary for Android**

0 commit comments
