Skip to content

Commit ae98cec

Browse files
authored
Merge branch 'main' into toupstream/linear_int16
2 parents aba691a + a4e7475 commit ae98cec

File tree

92 files changed

+436
-874
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

92 files changed

+436
-874
lines changed

.github/workflows/cuda.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ jobs:
8989
9090
export-voxtral-cuda-artifact:
9191
name: export-voxtral-cuda-${{ matrix.quant.name }}
92+
# Skip this job if the pull request is from a fork (HuggingFace secrets are not available)
93+
if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name != 'pull_request'
9294
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
9395
permissions:
9496
id-token: write
@@ -166,6 +168,8 @@ jobs:
166168
167169
export-gemma3-cuda-artifact:
168170
name: export-gemma3-cuda-${{ matrix.quant.name }}
171+
# Skip this job if the pull request is from a fork (HuggingFace secrets are not available)
172+
if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name != 'pull_request'
169173
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
170174
permissions:
171175
id-token: write

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,7 @@
8888
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
8989
from .remove_noop_pass import RemoveNoopPass # noqa
9090
from .replace_scalar_with_tensor_pass import ( # noqa
91-
ReplaceScalarWithTensorArgPassTOSABI,
92-
ReplaceScalarWithTensorArgPassTOSAMI,
91+
ReplaceScalarWithTensorByProfilePass,
9392
)
9493
from .rewrite_conv2d_pass import RewriteConv2dPass # noqa
9594
from .rewrite_matmul import RewriteMatmulPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@
8787
QuantizeOperatorArguments,
8888
RemoveNoopPass,
8989
ReplaceInfValues,
90-
ReplaceScalarWithTensorArgPassTOSABI,
91-
ReplaceScalarWithTensorArgPassTOSAMI,
90+
ReplaceScalarWithTensorByProfilePass,
9291
RetraceFoldedDtypesPass,
9392
RewriteConv2dPass,
9493
RewriteMatmulPass,
@@ -154,15 +153,15 @@ def _transform(self, graph_module: GraphModule):
154153
with TosaLoweringContext(self.tosa_spec):
155154
return self(graph_module).graph_module
156155

157-
def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
156+
def _tosa_INT_pipeline(
157+
self, exported_program: ExportedProgram, graph_module: GraphModule
158+
) -> GraphModule:
158159
self.add_pass(AnnotateOutputDimOrderPass())
159160
self.add_pass(FuseQuantizedActivationPass())
160161
self.add_pass(RemoveGetItemPass())
161162
self.add_pass(ConvertSplitToSlicePass())
162163
self.add_pass(ConvertMmToBmmPass())
163-
self.add_pass(
164-
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
165-
)
164+
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
166165
self.add_pass(ConvertFullLikeToFullPass())
167166
self.add_pass(ConvertToClampPass())
168167
self.add_pass(ConvertMinMaxPass())
@@ -172,7 +171,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
172171
self.add_pass(CastToInt32Pass())
173172

174173
self.add_pass(CastBoolToInt8Pass())
175-
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
174+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
176175
self.add_pass(AnnotateDecomposedMatmulPass())
177176
self.add_pass(QuantizeOperatorArguments())
178177
self.add_pass(ConvertELUParamsPass())
@@ -219,9 +218,11 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
219218
self.add_pass(InsertRescalePass())
220219

221220
self.validate_constraints_mandatory()
222-
return self._transform(exported_program.graph_module)
221+
return self._transform(graph_module)
223222

224-
def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
223+
def _tosa_FP_pipeline(
224+
self, exported_program: ExportedProgram, graph_module: GraphModule
225+
) -> GraphModule:
225226
self.add_pass(AnnotateOutputDimOrderPass())
226227
self.add_pass(DecomposeExpm1Pass())
227228
self.add_pass(DecomposeLogitPass())
@@ -242,7 +243,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
242243
self.add_pass(DecomposeSinhPass())
243244
self.add_pass(DecomposeSignPass())
244245
self.add_pass(DecomposeDivTensorModePass())
245-
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
246+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
246247
self.add_pass(DecomposeEmbeddingPass())
247248
self.add_pass(FuseQuantizedActivationPass())
248249
self.add_pass(RemoveGetItemPass())
@@ -256,9 +257,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
256257
self.add_pass(DecomposeLayerNormPass())
257258
self.add_pass(DecomposeBatchNormNoStatsPass())
258259
self.add_pass(DecomposeVarPass())
259-
self.add_pass(
260-
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
261-
)
260+
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
262261
self.add_pass(DecomposeNotEqualPass())
263262
self.add_pass(DecomposeDivPass())
264263
self.add_pass(DecomposeAddSubAlphaPass())
@@ -306,14 +305,16 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
306305
self.add_pass(InsertRescalePass())
307306

308307
self.validate_constraints_mandatory()
309-
return self._transform(exported_program.graph_module)
308+
return self._transform(graph_module)
310309

311-
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
310+
def transform_to_backend_pipeline(
311+
self, exported_program: ExportedProgram, graph_module: GraphModule
312+
):
312313
"""Apply passes before transforming program to backend"""
313314
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
314-
return self._tosa_FP_pipeline(exported_program)
315+
return self._tosa_FP_pipeline(exported_program, graph_module)
315316
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"):
316-
return self._tosa_INT_pipeline(exported_program)
317+
return self._tosa_INT_pipeline(exported_program, graph_module)
317318
else:
318319
raise NotImplementedError(
319320
f"No pass pipeline implemented for {self.tosa_spec=}"
@@ -335,7 +336,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
335336
self.add_pass(DecomposeAddmmPass())
336337
self.add_pass(DecomposeDivTensorModePass())
337338
self.add_pass(DecomposeAddSubAlphaPass())
338-
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
339+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
339340
self.add_pass(ScalarsToAttributePass())
340341
self.add_pass(DecomposeGroupNormPass())
341342
self.add_pass(DecomposeLayerNormPass())

backends/arm/_passes/decompose_acosh_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1313
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1414
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15-
ReplaceScalarWithTensorArgPassTOSAMI,
15+
ReplaceScalarWithTensorByProfilePass,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
@@ -32,7 +32,7 @@ class DecomposeAcoshPass(ArmPass):
3232
DecomposeSqrtPass,
3333
InsertTableOpsPass,
3434
MatchArgRanksPass,
35-
ReplaceScalarWithTensorArgPassTOSAMI,
35+
ReplaceScalarWithTensorByProfilePass,
3636
MatchArgDtypePass,
3737
}
3838

backends/arm/_passes/decompose_asin_and_acos_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
2020
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
2121
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
22-
ReplaceScalarWithTensorArgPassTOSAMI,
22+
ReplaceScalarWithTensorByProfilePass,
2323
)
2424
from executorch.exir.dialects._ops import ops as exir_ops
2525
from executorch.exir.pass_base import ExportPass
@@ -71,7 +71,7 @@ class DecomposeAsinAndAcosPass(ArmPass):
7171
ConvertFullLikeToFullPass,
7272
MatchArgRanksPass,
7373
MatchArgDtypePass,
74-
ReplaceScalarWithTensorArgPassTOSAMI,
74+
ReplaceScalarWithTensorByProfilePass,
7575
}
7676

7777
def _build_polynomial(

backends/arm/_passes/decompose_asinh_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1313
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1414
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15-
ReplaceScalarWithTensorArgPassTOSAMI,
15+
ReplaceScalarWithTensorByProfilePass,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
@@ -32,7 +32,7 @@ class DecomposeAsinhPass(ArmPass):
3232
DecomposeSqrtPass,
3333
InsertTableOpsPass,
3434
MatchArgRanksPass,
35-
ReplaceScalarWithTensorArgPassTOSAMI,
35+
ReplaceScalarWithTensorByProfilePass,
3636
MatchArgDtypePass,
3737
}
3838

backends/arm/_passes/decompose_atan_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1313
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1414
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15-
ReplaceScalarWithTensorArgPassTOSAMI,
15+
ReplaceScalarWithTensorByProfilePass,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
@@ -47,7 +47,7 @@ class DecomposeAtanPass(ArmPass):
4747
InsertTableOpsPass,
4848
MatchArgRanksPass,
4949
MatchArgDtypePass,
50-
ReplaceScalarWithTensorArgPassTOSAMI,
50+
ReplaceScalarWithTensorByProfilePass,
5151
}
5252

5353
def _rational_approximation(self, z, ops, meta):

backends/arm/_passes/decompose_atanh_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1111
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1212
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
13-
ReplaceScalarWithTensorArgPassTOSAMI,
13+
ReplaceScalarWithTensorByProfilePass,
1414
)
1515
from executorch.exir.dialects._ops import ops as exir_ops
1616
from executorch.exir.pass_base import ExportPass
@@ -43,7 +43,7 @@ class DecomposeAtanhPass(ArmPass):
4343
InsertTableOpsPass,
4444
MatchArgRanksPass,
4545
MatchArgDtypePass,
46-
ReplaceScalarWithTensorArgPassTOSAMI,
46+
ReplaceScalarWithTensorByProfilePass,
4747
}
4848

4949
def call_operator(self, op, args, kwargs, meta):

backends/arm/_passes/decompose_cosh_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1111
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1212
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
13-
ReplaceScalarWithTensorArgPassTOSAMI,
13+
ReplaceScalarWithTensorByProfilePass,
1414
)
1515
from executorch.exir.dialects._ops import ops as exir_ops
1616
from executorch.exir.pass_base import ExportPass
@@ -31,7 +31,7 @@ class DecomposeCoshPass(ArmPass):
3131
_passes_required_after: Set[Type[ExportPass]] = {
3232
InsertTableOpsPass,
3333
MatchArgRanksPass,
34-
ReplaceScalarWithTensorArgPassTOSAMI,
34+
ReplaceScalarWithTensorByProfilePass,
3535
MatchArgDtypePass,
3636
}
3737

backends/arm/_passes/decompose_expm1_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
1313
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
1414
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15-
ReplaceScalarWithTensorArgPassTOSAMI,
15+
ReplaceScalarWithTensorByProfilePass,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818
from executorch.exir.pass_base import ExportPass
@@ -83,7 +83,7 @@ class DecomposeExpm1Pass(ArmPass):
8383
ConvertIntPowToMuls,
8484
InsertTableOpsPass,
8585
DecomposeDivPass,
86-
ReplaceScalarWithTensorArgPassTOSAMI,
86+
ReplaceScalarWithTensorByProfilePass,
8787
MatchArgDtypePass,
8888
MatchArgRanksPass,
8989
}

0 commit comments

Comments
 (0)