
Commit b657769

Merge branch 'main' into gh/trivedivivek/108/orig

2 parents: 4ec327e + 088815e
File tree: 9 files changed, +155 −72 lines

backends/xnnpack/operators/node_visitor.py

Lines changed: 32 additions & 5 deletions
@@ -274,19 +274,46 @@ def get_per_channel_dtype(
         return dtype

-    def get_quant_params(self, quant_params: QuantParams) -> XNNQuantParams:
+    def get_quant_params(
+        self, quant_params: QuantParams, xnn_graph: XNNGraph
+    ) -> XNNQuantParams:
         if quant_params.per_channel:
             scale = cast(torch.Tensor, quant_params.scale)
+            buffer_idx = len(xnn_graph.constant_data)
+            num_scales = scale.numel()
+
+            if quant_params.is_per_channel_group:
+                scale = scale.to(torch.bfloat16)
+
+            num_bytes = scale.untyped_storage().nbytes()
+            scale_array = ctypes.cast(
+                scale.untyped_storage().data_ptr(),
+                ctypes.POINTER(ctypes.c_char * num_bytes),
+            ).contents
+            scale_name = hashlib.sha256(bytes(scale_array)).hexdigest()
+            xnn_graph.constant_data.append(
+                ConstantDataOffset(
+                    offset=UINT64_MAX, size=num_bytes, named_key=scale_name
+                )
+            )
+            self._named_data_store.add_named_data(
+                scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT
+            )
+
             if quant_params.is_per_channel_group:
                 return PerChannelGroupQuant(
-                    scale=scale.flatten().tolist(),
+                    scale=[],
                     channel_dim=quant_params.axis,
                     group_size=quant_params.group_size,
+                    scale_buffer_idx=buffer_idx,
+                    num_scales=num_scales,
                 )
-            else:  # per_channel quant
+            else:
                 return PerChannelQuant(
-                    scale=scale.tolist(),
+                    scale=[],
                     channel_dim=quant_params.axis,
+                    scale_buffer_idx=buffer_idx,
+                    num_scales=num_scales,
                 )
         elif quant_params.is_dynamic:
             # NB:

@@ -449,7 +476,7 @@ def define_tensor(  # noqa: C901
             else XValue(
                 xvalue_union=XNNQuantizedTensorValue(
                     tensor_value=tvalue,
-                    quant_params=self.get_quant_params(quant_params),
+                    quant_params=self.get_quant_params(quant_params, xnn_graph),
                 )
             )
         )
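
The exporter-side change above keys each scale blob by the SHA-256 of its raw bytes, so identical scale tensors collapse to a single named-data entry. Below is a minimal standalone sketch of that content-addressed keying; the helper name is ours, and the ConstantDataOffset bookkeeping is omitted:

import ctypes
import hashlib

import torch


def scale_to_named_bytes(scale: torch.Tensor) -> tuple[str, bytes]:
    # Reinterpret the tensor's storage as raw bytes; this also works for
    # bfloat16, which has no numpy dtype to round-trip through.
    num_bytes = scale.untyped_storage().nbytes()
    raw = ctypes.cast(
        scale.untyped_storage().data_ptr(),
        ctypes.POINTER(ctypes.c_char * num_bytes),
    ).contents
    data = bytes(raw)
    # Content-addressed key: equal tensors hash to the same named-data blob.
    return hashlib.sha256(data).hexdigest(), data


key, blob = scale_to_named_bytes(torch.randn(8).to(torch.bfloat16))
print(key[:16], len(blob))  # 8 bf16 scales -> 16 bytes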

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 39 additions & 4 deletions
@@ -421,11 +421,32 @@ Error defineTensor(
           qparams->channel_dim(),
           dtype,
           zero_point);
+
+      const float* scale = qparams->scale()->data();
+
+      if (qparams->scale_buffer_idx() != 0) {
+        // If the scales are stored in named data, retrieve them.
+        ConstantDataOffsetPtr scale_buffer_offset =
+            flatbuffer_graph->constant_data()->Get(
+                qparams->scale_buffer_idx());
+        const std::string& data_name =
+            scale_buffer_offset->named_key()->str();
+        Result<FreeableBuffer> scale_buffer =
+            named_data_map->get_data(data_name.c_str());
+        ET_CHECK_OR_RETURN_ERROR(
+            scale_buffer.ok(),
+            Internal,
+            "Failed to get constant data for key %s from named_data_map. Error code: %u",
+            data_name.c_str(),
+            static_cast<uint32_t>(scale_buffer.error()));
+        scale = reinterpret_cast<const float*>(scale_buffer.get().data());
+        freeable_buffers.push_back(std::move(scale_buffer.get()));
+      }
       status = xnn_define_channelwise_quantized_tensor_value_v2(
           /*subgraph=*/subgraph_ptr,
           /*datatype=*/dtype,
           /*zero_point=*/zero_point,
-          /*scale=*/qparams->scale()->data(),
+          /*scale=*/scale,
           /*num_dims=*/tensor_value->num_dims(),
           /*channel_dim*/ qparams->channel_dim(),
           /*dims=*/dims_data.data(),

@@ -452,10 +473,24 @@ Error defineTensor(

       // Block scales are preferably serialized as bf16 but can also be
       // serialized as fp32 for backwards compatibility.
-      if (qparams->scale_bf16() != nullptr) {
+      if (qparams->scale_buffer_idx() != 0) {
+        ConstantDataOffsetPtr scale_buffer_offset =
+            flatbuffer_graph->constant_data()->Get(
+                qparams->scale_buffer_idx());
+        const std::string& data_name =
+            scale_buffer_offset->named_key()->str();
+        Result<FreeableBuffer> scale_buffer =
+            named_data_map->get_data(data_name.c_str());
+        ET_CHECK_OR_RETURN_ERROR(
+            scale_buffer.ok(),
+            Internal,
+            "Failed to get constant data for key %s from named_data_map. Error code: %u",
+            data_name.c_str(),
+            static_cast<uint32_t>(scale_buffer.error()));
         scale_data =
-            static_cast<const uint16_t*>(qparams->scale_bf16()->data());
-        scale_numel = qparams->scale_bf16()->size();
+            reinterpret_cast<const uint16_t*>(scale_buffer.get().data());
+        freeable_buffers.push_back(std::move(scale_buffer.get()));
+        scale_numel = qparams->num_scales();
       } else {
         // Read fp32 scales, convert to bf16.
         auto conv_buffer = static_cast<uint16_t*>(allocator.allocateTemporary(
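
Note that the runtime reinterprets the named-data bytes in place (a reinterpret_cast to const uint16_t*) rather than converting them, so the blob must already contain bf16 words; and because the buffer only knows its byte size, the element count now travels separately as num_scales. A sketch of that round-trip in Python, assuming torch.frombuffer supports bfloat16 on the PyTorch version in use:

import ctypes

import torch


def tensor_bytes(t: torch.Tensor) -> bytes:
    # Raw storage bytes, as the exporter serializes them.
    n = t.untyped_storage().nbytes()
    return bytes(
        ctypes.cast(
            t.untyped_storage().data_ptr(), ctypes.POINTER(ctypes.c_char * n)
        ).contents
    )


scales = torch.tensor([0.5, 0.25, 2.0], dtype=torch.bfloat16)
blob = tensor_bytes(scales)  # 2 bytes per bf16 element
# Reinterpret the bytes, mirroring the runtime's cast over the FreeableBuffer.
restored = torch.frombuffer(bytearray(blob), dtype=torch.bfloat16)
assert torch.equal(restored, scales)  # lossless round-trip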

backends/xnnpack/serialization/runtime_schema.fbs

Lines changed: 5 additions & 1 deletion
@@ -48,6 +48,8 @@ table Buffer {
 table PerChannelQuant {
   scale:[float];
   channel_dim:int;
+  scale_buffer_idx: uint;
+  num_scales: uint;
 }

 table PerTokenDynamicQuant {

@@ -63,7 +65,9 @@ table PerChannelGroupQuant {
   scale:[float];
   channel_dim:int;
   group_size:int;
-  scale_bf16:[ushort];
+  scale_bf16:[ushort] (deprecated);
+  scale_buffer_idx: uint;
+  num_scales: uint;
 }

 table XNNTensorValue {

backends/xnnpack/serialization/schema.fbs

Lines changed: 5 additions & 1 deletion
@@ -48,12 +48,16 @@ table PerChannelGroupQuant {
   scale:[float];
   channel_dim:int;
   group_size:int;
-  scale_bf16:[ushort];
+  scale_bf16:[ushort] (deprecated);
+  scale_buffer_idx: uint;
+  num_scales: uint;
 }

 table PerChannelQuant {
   scale:[float];
   channel_dim:int;
+  scale_buffer_idx: uint;
+  num_scales: uint;
 }

 table PerTokenDynamicQuant {

backends/xnnpack/serialization/xnnpack_graph_schema.py

Lines changed: 10 additions & 0 deletions
@@ -425,13 +425,23 @@ class XNNDatatype(IntEnum):
 class PerChannelQuant:
     scale: List[float]
     channel_dim: int
+    scale_buffer_idx: int = -1
+    num_scales: int = -1
+
+
+@dataclass
+class Buffer:
+    storage: bytes


 @dataclass
 class PerChannelGroupQuant:
     scale: List[float]
     channel_dim: int
     group_size: int = 1
+    scale_bf16: Optional[List[float]] = None
+    scale_buffer_idx: int = -1
+    num_scales: int = -1


 @dataclass
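
For illustration, here is how the new dataclass fields end up populated when scales move to named data; the numeric values are hypothetical, and the import path assumes the executorch package layout matches the file paths above:

from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    PerChannelGroupQuant,
)

pcg = PerChannelGroupQuant(
    scale=[],            # inline scales are left empty
    channel_dim=0,
    group_size=32,
    scale_buffer_idx=3,  # index into xnn_graph.constant_data
    num_scales=128,      # element count; the byte size alone is ambiguous
)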

devtools/inspector/_intermediate_output_capturer.py

Lines changed: 36 additions & 3 deletions
@@ -7,24 +7,57 @@
 # pyre-unsafe


-from typing import Any, Dict, Tuple
+from typing import Any, Dict, List, Tuple

 import torch
 from torch.fx import GraphModule
 from torch.fx.interpreter import Interpreter


+class NodeFilter:
+    """
+    A class used to filter nodes based on extensible criteria.
+    Attributes:
+        metadata_key (str): The key to look for in the node's metadata.
+        op_type (str): The operation code to match.
+        exclude_ops (List[str]): A list of operations to exclude from the filter.
+    """
+
+    def __init__(self, metadata_key: str, op_type: str, exclude_ops: List[str] = None):
+        self.metadata_key = metadata_key
+        self.op_type = op_type
+        self.exclude_ops = exclude_ops
+
+    def matches(self, node: torch.fx.Node) -> bool:
+        return (
+            node.meta.get(self.metadata_key) is not None
+            and node.op == self.op_type
+            and all(exclude_name not in node.name for exclude_name in self.exclude_ops)
+        )
+
+
 class IntermediateOutputCapturer(Interpreter):
+    """
+    A class that captures intermediate outputs from a PyTorch graph module.
+    Attributes:
+        module (GraphModule): The graph module to capture outputs from.
+        node_filters (List[NodeFilter]): A list of filters to apply to the nodes.
+    """
+
     def __init__(self, module: GraphModule):
         super().__init__(module)
+        self.node_filters = [
+            NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"])
+        ]

+    # Runs the graph module and captures the intermediate outputs.
     def run_and_capture(self, *args, **kwargs) -> Dict[Tuple[int, ...], Any]:
         captured_outputs = {}

         def capture_run_node(n: torch.fx.Node) -> Any:
             result = super(IntermediateOutputCapturer, self).run_node(n)
-            debug_handle = n.meta.get("debug_handle", None)
-            if debug_handle is not None and n.op == "call_function":
+            if all(filter.matches(n) for filter in self.node_filters):
+                debug_handle = n.meta["debug_handle"]
                 # Convert the debug handle to a tuple to use as a dictionary key
                 key = (
                     (debug_handle,)
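
A quick, self-contained check of the NodeFilter semantics above on a toy torch.fx graph; the traced function and the fake debug_handle metadata are ours, for illustration only:

import torch
from executorch.devtools.inspector._intermediate_output_capturer import NodeFilter
from torch.fx import symbolic_trace


def f(x):
    return torch.relu(x + 1)


gm = symbolic_trace(f)
filt = NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"])
for node in gm.graph.nodes:
    node.meta["debug_handle"] = 1  # pretend every node carries a handle
    # True only for call_function nodes whose name does not contain "getitem"
    print(node.name, filt.matches(node))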

devtools/inspector/tests/intermediate_output_capturer_test.py

Lines changed: 0 additions & 2 deletions
@@ -111,8 +111,6 @@ def test_capture_correct_outputs(self):
             (19,): torch.tensor([[3.6000, 4.5067]]),
             (20,): torch.tensor([[0.9734, 0.9891]]),
             (21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
-            (22,): torch.tensor([[0.9734]]),
-            (23,): torch.tensor([[0.9891]]),
         }
         self.assertEqual(
             len(self.intermediate_outputs), len(expected_outputs_with_handles)

examples/models/llama/source_transformation/quantize.py

Lines changed: 2 additions & 56 deletions
@@ -8,16 +8,14 @@
 import re
 from functools import partial
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Dict, Optional

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from executorch.extension.llm.export.builder import DType

-from sentencepiece import SentencePieceProcessor
-

 try:
     from fairseq2.nn.embedding import (

@@ -57,7 +55,7 @@ def quantize(  # noqa C901

     Args:
         model: The model to quantize.
-        qmode: The quantization mode, e.g. int8, 8da4w, 8da4w-gptq.
+        qmode: The quantization mode, e.g. int8, 8da4w.
         computation_dtype: The dtype that ops are performed in (the resulting dtype of dequantization).
             Also the dtype of the rest of the non-quantized components of the model.
         checkpoint_dtype: The dtype of the checkpoint, this arg exists since it is more accurate to

@@ -161,58 +159,6 @@ def quantize(  # noqa C901
         if verbose:
             print("quantized model:", model)
         return model
-    elif qmode == "8da4w-gptq":
-        # Check for required args
-        required_args: Optional[Any] = [
-            group_size,
-            calibration_limit,
-            calibration_seq_length,
-        ]
-        if any(arg is None for arg in required_args):
-            raise Exception(
-                "For 8da4w-gptq quantization, group size, calibration limit and calibration sequence length must be specified."
-            )
-        if calibration_tasks is None:
-            calibration_tasks = ["wikitext"]
-
-        try:
-            # torchao 0.3+
-            from torchao._models._eval import InputRecorder
-        except ImportError:
-            from torchao.quantization.GPTQ import InputRecorder  # pyre-ignore
-
-        from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
-
-        if tokenizer_path is None:
-            assert checkpoint_path is not None, "checkpoint_path must be specified"
-            tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-        assert tokenizer_path.is_file(), tokenizer_path
-        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
-            model_file=str(tokenizer_path)
-        )
-
-        inputs = (
-            InputRecorder(  # pyre-fixme[16]
-                tokenizer,
-                calibration_seq_length,
-                None,  # input_prep_func
-                pad_calibration_inputs,
-                model.vocab_size,
-            )
-            .record_inputs(
-                calibration_tasks,
-                calibration_limit,
-            )
-            .get_inputs()
-        )
-
-        gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
-            blocksize,
-            percdamp,
-            group_size,
-        )  # TODO: separate computation and checkpoint dtype for GPTQ.
-        model = gptq_quantizer.quantize(model, inputs)
-        return model
     elif qmode == "vulkan_4w":
         from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer

extension/threadpool/cpuinfo_utils.cpp

Lines changed: 26 additions & 0 deletions
@@ -16,6 +16,10 @@

 #include <executorch/runtime/platform/assert.h>

+#if defined(__APPLE__) && defined(__aarch64__)
+#include <sys/sysctl.h>
+#endif
+
 namespace executorch::extension::cpuinfo {

 // Ignore revisions (last digit (4 LSBs))

@@ -33,6 +37,11 @@ bool is_non_performant_core(const struct cpuinfo_uarch_info* uarch_info) {
     case cpuinfo_uarch_cortex_a53:
     case cpuinfo_uarch_cortex_a510:
     case cpuinfo_uarch_icestorm:
+    case cpuinfo_uarch_blizzard:
+    case cpuinfo_uarch_sawtooth:
+    case cpuinfo_uarch_coll_sawtooth:
+    case cpuinfo_uarch_tupai_sawtooth:
+    case cpuinfo_uarch_tahiti_sawtooth:
       return true;
     // This can be so many other cores.
     // Need to update this to better account for slow cores

@@ -167,6 +176,23 @@ uint32_t get_num_performant_cores() {
   // On the OnePlus 12, although it has 2 little cores, the topology
   // reported in /sys/devices/system/cpu/cpu*/topology/core_siblings_list
   // is wrong, which results in the wrong configuration
+#if defined(__aarch64__) && defined(__APPLE__)
+  // Copied from ATen/ParallelCommon.cpp
+  // On Apple Silicon there are efficiency and performance cores
+  // Restrict parallel algorithms to performance cores by default
+  int32_t num_cores = -1;
+  size_t num_cores_len = sizeof(num_cores);
+  if (sysctlbyname(
+          "hw.perflevel0.physicalcpu",
+          &num_cores,
+          &num_cores_len,
+          nullptr,
+          0) == 0) {
+    if (num_cores > 1) {
+      return static_cast<uint32_t>(num_cores);
+    }
+  }
+#endif
   return _get_num_performant_cores();
 }
 }
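
The same sysctl key can be probed from Python for a quick sanity check; this is macOS-only and purely illustrative, since libc on other platforms has no sysctlbyname:

import ctypes
import ctypes.util

libc = ctypes.CDLL(ctypes.util.find_library("c"))
val = ctypes.c_int32(-1)
size = ctypes.c_size_t(ctypes.sizeof(val))
# hw.perflevel0.physicalcpu = number of physical performance (P) cores
ok = libc.sysctlbyname(
    b"hw.perflevel0.physicalcpu",
    ctypes.byref(val),
    ctypes.byref(size),
    None,
    ctypes.c_size_t(0),
)
if ok == 0 and val.value > 1:
    print("performance cores:", val.value)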
