Skip to content

Commit a5e2abe

Browse files
committed
Update
[ghstack-poisoned]
2 parents 23f7286 + ed482bd commit a5e2abe

File tree

24 files changed

+361
-616
lines changed

24 files changed

+361
-616
lines changed

.ci/scripts/test_llava.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ run_and_verify() {
149149

150150
# verify result.txt
151151
RESULT=$(cat result.txt)
152-
EXPECTED_PREFIX="ASSISTANT: image captures a basketball game in progress, with"
152+
EXPECTED_PREFIX="ASSISTANT: The image captures a basketball game in progress, with"
153153

154154
if [[ "${RESULT}" == *"${EXPECTED_PREFIX}"* ]]; then
155155
echo "Expected result prefix: ${EXPECTED_PREFIX}"

.github/workflows/_link_check.yml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,29 @@ jobs:
5555
echo "Or add \`@lint-ignore\` somewhere on the same line as the reference you want to skip checking."
5656
exit 1
5757
}
58+
59+
lint-file-size:
60+
if: ${{ github.event_name == 'pull_request' }}
61+
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
62+
with:
63+
runner: linux.2xlarge
64+
docker-image: ci-image:executorch-ubuntu-22.04-linter
65+
submodules: false
66+
fetch-depth: 0
67+
ref: ${{ inputs.ref }}
68+
timeout: 30
69+
script: |
70+
chmod +x ./scripts/lint_file_size.sh
71+
./scripts/lint_file_size.sh $(
72+
if [ "${{ github.event_name }}" = "pull_request" ]; then
73+
echo "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}"
74+
else
75+
echo "${{ github.event.before }}" "${{ github.sha }}"
76+
fi
77+
) || {
78+
echo
79+
echo "File size lint failed: some files exceed the 1 MB limit."
80+
echo "If you really need large files, consider using Git LFS or storing them elsewhere."
81+
echo "If you really need to get unblocked and check in the file, can add it to the EXCEPTIONS list in scripts/lint_file_size.sh."
82+
exit 1
83+
}

backends/cadence/aot/ops_registrations.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,19 @@
324324
"rope.out(Tensor input, Tensor sin_tensor, Tensor cos_tensor, Tensor? pos, *, Tensor(a!) out) -> Tensor(a!)"
325325
)
326326

327+
lib.define(
328+
"quantized_softmax(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)"
329+
)
330+
lib.define(
331+
"quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)"
332+
)
333+
lib.define(
334+
"quantized_softmax.out(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
335+
)
336+
lib.define(
337+
"quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
338+
)
339+
327340
# Load/store with iDMA. These only exist before memory planning.
328341
# Post memory planning, we check that outputs/inputs for the load/store are in
329342
# DTCM and replace idma_load/idma_store with idma_copy.
@@ -2329,3 +2342,29 @@ def softmax_f32_f32_meta(
23292342
half_to_float: Optional[bool] = None,
23302343
) -> torch.Tensor:
23312344
return self.new_empty(self.size(), dtype=self.dtype)
2345+
2346+
2347+
@register_fake("cadence::quantized_softmax")
2348+
def quantized_softmax_meta(
2349+
input: torch.Tensor,
2350+
mask: torch.Tensor,
2351+
dim: int,
2352+
in_scale: torch.Tensor,
2353+
in_zero_point: torch.Tensor,
2354+
out_scale: torch.Tensor,
2355+
out_zero_point: torch.Tensor,
2356+
) -> torch.Tensor:
2357+
return input.new_empty(input.size(), dtype=input.dtype)
2358+
2359+
2360+
@register_fake("cadence::quantized_softmax.per_tensor")
2361+
def quantized_softmax_per_tensor_meta(
2362+
input: torch.Tensor,
2363+
mask: torch.Tensor,
2364+
dim: int,
2365+
in_scale: float,
2366+
in_zero_point: int,
2367+
out_scale: float,
2368+
out_zero_point: int,
2369+
) -> torch.Tensor:
2370+
return input.new_empty(input.size(), dtype=input.dtype)

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66

77
# pyre-strict
88

9-
from typing import Any, Dict, List, Tuple
9+
from typing import Any, cast, Dict, List, Tuple
1010

1111
import torch
12+
from executorch.backends.cadence.aot.compiler_utils import get_shape
1213
from executorch.backends.cadence.aot.quantizer.patterns import (
1314
AddmmPattern,
1415
AddPattern,
@@ -25,6 +26,7 @@
2526
MatmulPattern,
2627
ReluPattern0,
2728
ReluPattern1,
29+
SoftmaxPattern,
2830
)
2931
from executorch.backends.cadence.aot.quantizer.utils import (
3032
check_out_zero_point_is_min_range,
@@ -388,6 +390,73 @@ def get_args_and_kwargs_relu(
388390
return args, kwargs
389391

390392

393+
def get_args_and_kwargs_softmax(
394+
graph_module: GraphModule,
395+
inputs_inputs: List[fx.Node],
396+
dequants_inputs: List[fx.Node],
397+
quant_node: fx.Node,
398+
op_node: fx.Node,
399+
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
400+
# Make a dummy mask tensor
401+
mask_shape = get_shape(graph_module, cast(fx.Node, quant_node.args[0]))
402+
mask_shape = list(mask_shape) if mask_shape else []
403+
mask_shape[-1] = mask_shape[-1] // 16
404+
mask_tensor = graph_module.graph.call_function(
405+
torch.ops.aten.full.default,
406+
(
407+
mask_shape,
408+
0.0,
409+
),
410+
{"dtype": torch.int32},
411+
)
412+
# Make the scale and zero_point tensors
413+
in_scale_tensor = graph_module.graph.call_function(
414+
torch.ops.aten.full.default,
415+
(
416+
[1],
417+
dequants_inputs[0].args[1],
418+
),
419+
{"dtype": torch.float32},
420+
)
421+
in_zero_point_tensor = graph_module.graph.call_function(
422+
torch.ops.aten.full.default,
423+
(
424+
[1],
425+
dequants_inputs[0].args[2],
426+
),
427+
{"dtype": torch.int32},
428+
)
429+
out_scale_tensor = graph_module.graph.call_function(
430+
torch.ops.aten.full.default,
431+
(
432+
[1],
433+
quant_node.args[1],
434+
),
435+
{"dtype": torch.float32},
436+
)
437+
out_zero_point_tensor = graph_module.graph.call_function(
438+
torch.ops.aten.full.default,
439+
(
440+
[1],
441+
quant_node.args[2],
442+
),
443+
{"dtype": torch.int32},
444+
)
445+
446+
# Make the args and kwargs for the replacement op
447+
args = (
448+
inputs_inputs[0],
449+
mask_tensor,
450+
op_node.args[1],
451+
in_scale_tensor,
452+
in_zero_point_tensor,
453+
out_scale_tensor,
454+
out_zero_point_tensor,
455+
)
456+
kwargs = {}
457+
return args, kwargs
458+
459+
391460
class QuantFusion(ExportPass):
392461
# pyre-ignore[2]: Parameter `patterns` has no type specified
393462
def __init__(self, patterns) -> None:
@@ -543,6 +612,14 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
543612
dequants_inputs,
544613
quant_node,
545614
)
615+
elif isinstance(pattern, SoftmaxPattern):
616+
args, kwargs = get_args_and_kwargs_softmax(
617+
graph_module,
618+
inputs_inputs,
619+
dequants_inputs,
620+
quant_node,
621+
anchor_output_node,
622+
)
546623
fused = graph_module.graph.call_function(
547624
pattern.replacement_op(),
548625
args,

backends/cadence/aot/quantizer/patterns.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,3 +485,25 @@ def partition_types(self) -> List[OpOverload]:
485485
class Conv2dReluPattern1(ConvReluBasePattern):
486486
def partition_types(self) -> List[OpOverload]:
487487
return [torch.ops.aten.conv2d.default, torch.ops.aten.relu_.default]
488+
489+
490+
class SoftmaxPattern(QuantizationPattern):
491+
492+
def partition_types(self) -> List[OpOverload]:
493+
return [torch.ops.aten._softmax.default]
494+
495+
def get_anchors(
496+
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
497+
) -> PartitionAnchors:
498+
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
499+
softmax_node = fused_partition[0].nodes[-1]
500+
501+
return PartitionAnchors(
502+
inputs=[(softmax_node, 0)],
503+
weights=[],
504+
biases=[],
505+
output=[(softmax_node,)],
506+
)
507+
508+
def replacement_op(self) -> OpOverload:
509+
return torch.ops.cadence.quantized_softmax.default

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
QuantizationPattern,
2828
ReluPattern0,
2929
ReluPattern1,
30+
SoftmaxPattern,
3031
)
3132
from executorch.backends.cadence.aot.quantizer.utils import (
3233
find_sequential_partitions_aten,
@@ -58,6 +59,15 @@
5859
observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
5960
)
6061

62+
act_qspec_asym16s = QuantizationSpec(
63+
dtype=torch.int16,
64+
quant_min=-32768,
65+
quant_max=32767,
66+
qscheme=torch.per_tensor_affine,
67+
is_dynamic=False,
68+
observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
69+
)
70+
6171
wgt_qspec_asym8s = QuantizationSpec(
6272
dtype=torch.int8,
6373
quant_min=-128,
@@ -92,6 +102,13 @@
92102
None,
93103
)
94104

105+
qconfig_A16 = QuantizationConfig(
106+
act_qspec_asym16s,
107+
act_qspec_asym16s,
108+
wgt_qspec_asym8s,
109+
None,
110+
)
111+
95112

96113
class CadenceAtenQuantizer(Quantizer):
97114
def __init__(
@@ -283,3 +300,15 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
283300
quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
284301
quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
285302
super().__init__(quantizers)
303+
304+
305+
class CadenceWithSoftmaxQuantizer(CadenceQuantizer):
306+
"""
307+
Quantizer including A16 softmax
308+
"""
309+
310+
def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
311+
if quantizers is None:
312+
quantizers = get_cadence_default_quantizers()
313+
quantizers.append(CadenceAtenQuantizer(SoftmaxPattern(), qconfig_A16))
314+
super().__init__(quantizers)

backends/cadence/aot/replace_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2250,6 +2250,7 @@ class CommonReplacePasses:
22502250
ReplaceMMWithAddMMPass,
22512251
ReplaceRepeatWithCatPass,
22522252
ReplaceFullLikeWithFullPass,
2253+
ReplaceAtenConvolutionWithCadenceConvolutionPass,
22532254
]
22542255

22552256

@@ -2282,7 +2283,6 @@ class CadenceReplaceOpsInGraph:
22822283
RemoveNopSelectOpPass,
22832284
ReplacePadWithCatPass,
22842285
ReplaceConstantPadNdWithSlicePass,
2285-
ReplaceAtenConvolutionWithCadenceConvolutionPass,
22862286
ReplaceConvWithChannelLastConvPass,
22872287
ReplaceTrivialConvWithLinear,
22882288
ReplaceConvWithIm2RowAndLinear,

examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,9 @@ class StaticAttentionIOManager {
589589
size_t prefill(
590590
executorch::runtime::Span<TokenT> tokens,
591591
executorch::runtime::Span<TokenT> input_buffer,
592-
executorch::runtime::Method& method) {
592+
executorch::runtime::Method& method,
593+
std::function<void(executorch::runtime::Span<const float>)>
594+
logits_callback = nullptr) {
593595
ET_LOG(Info, "Prefilling at position %zu", input_pos_);
594596
size_t input_len = input_buffer.size();
595597
auto& masks = get_mask(input_buffer.size());
@@ -610,6 +612,13 @@ class StaticAttentionIOManager {
610612
config_.k_cache_output_indices,
611613
config_.v_cache_output_indices,
612614
batch_len);
615+
if (logits_callback) {
616+
auto logits_tensor = method.get_output(0).toTensor();
617+
auto* logits = logits_tensor.const_data_ptr<float>();
618+
logits_callback(executorch::runtime::Span(
619+
logits,
620+
logits + batch_len * logits_tensor.size(logits_tensor.dim() - 1)));
621+
}
613622
}
614623
return batch_len - 1;
615624
}

examples/models/llava/CMakeLists.txt

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,7 @@ list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..)
7979
find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH)
8080
executorch_target_link_options_shared_lib(executorch)
8181

82-
# llava_runner library
83-
add_subdirectory(runner)
84-
85-
set(LINK_LIBS executorch gflags)
82+
set(LINK_LIBS executorch gflags extension_llm_runner)
8683
set(link_libraries ${LINK_LIBS})
8784
set(_srcs main.cpp)
8885

@@ -204,5 +201,5 @@ if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
204201
endif()
205202

206203
target_include_directories(llava_main PUBLIC ${_common_include_directories})
207-
target_link_libraries(llava_main PUBLIC llava_runner ${link_libraries})
204+
target_link_libraries(llava_main PUBLIC ${link_libraries})
208205
target_compile_options(llava_main PUBLIC ${_common_compile_options})

0 commit comments

Comments
 (0)