|
103 | 103 | RemoveNoopPass, |
104 | 104 | ReplaceInfAndLimitValuesPass, |
105 | 105 | ReplaceScalarWithTensorByProfilePass, |
| 106 | + RewriteBoolToFp32CastViaInt8Pass, |
106 | 107 | RewriteConvPass, |
107 | 108 | RewriteMatmulPass, |
108 | 109 | RewriteUpsamplePass, |
@@ -221,6 +222,7 @@ def _tosa_pipeline( |
221 | 222 | self.add_passes( |
222 | 223 | [ |
223 | 224 | FuseQuantizedActivationPass(), |
| 225 | + RewriteBoolToFp32CastViaInt8Pass(), |
224 | 226 | ConvertToClampPass(), |
225 | 227 | DecomposeTOSAUnsupportedClampPass(), |
226 | 228 | DecomposeGroupNormPass(), |
@@ -374,65 +376,65 @@ def transform_to_backend_pipeline( |
374 | 376 |
|
375 | 377 | def transform_for_annotation_pipeline(self, graph_module: GraphModule): |
376 | 378 | # Preprocessing passes |
377 | | - self.add_pass(RemoveGraphAssertsPass()) |
| 379 | + self.add_pass(RemoveGraphAssertsPass(tfa_pass=True)) |
378 | 380 |
|
379 | 381 | # Transformation passes (pre scalar -> tensor) |
380 | 382 | self.add_passes( |
381 | 383 | [ |
382 | | - DecomposeSelectScatterPass(), |
383 | | - ConvertInt64ConstOpsToInt32Pass(), |
384 | | - ConvertInt64OutputOpsToInt32Pass(), |
385 | | - InsertInt32CastsAfterInt64PlaceholdersPass(), |
386 | | - DecomposeEmbeddingPass(), |
387 | | - DecomposeScaledDotProductAttentionPass(), |
388 | | - DecomposeRoundPass(), |
389 | | - DecomposeLogitPass(), |
390 | | - PromoteBoolOperandsPass(), |
391 | | - DecomposeSignPass(), |
392 | | - DecomposeAddmmPass(), |
393 | | - DecomposeRemainderPass(), |
394 | | - DecomposeFloorDividePass(), |
395 | | - DecomposeDivTensorModePass(), |
| 384 | + DecomposeSelectScatterPass(tfa_pass=True), |
| 385 | + ConvertInt64ConstOpsToInt32Pass(tfa_pass=True), |
| 386 | + ConvertInt64OutputOpsToInt32Pass(tfa_pass=True), |
| 387 | + InsertInt32CastsAfterInt64PlaceholdersPass(tfa_pass=True), |
| 388 | + DecomposeEmbeddingPass(tfa_pass=True), |
| 389 | + DecomposeScaledDotProductAttentionPass(tfa_pass=True), |
| 390 | + DecomposeRoundPass(tfa_pass=True), |
| 391 | + DecomposeLogitPass(tfa_pass=True), |
| 392 | + PromoteBoolOperandsPass(tfa_pass=True), |
| 393 | + DecomposeSignPass(tfa_pass=True), |
| 394 | + DecomposeAddmmPass(tfa_pass=True), |
| 395 | + DecomposeRemainderPass(tfa_pass=True), |
| 396 | + DecomposeFloorDividePass(tfa_pass=True), |
| 397 | + DecomposeDivTensorModePass(tfa_pass=True), |
396 | 398 | ] |
397 | 399 | ) |
398 | 400 |
|
399 | 401 | # Scalars -> tensors |
400 | 402 | self.add_passes( |
401 | 403 | [ |
402 | | - ReplaceScalarWithTensorByProfilePass(), |
403 | | - ScalarsToAttributePass(), |
| 404 | + ReplaceScalarWithTensorByProfilePass(tfa_pass=True), |
| 405 | + ScalarsToAttributePass(tfa_pass=True), |
404 | 406 | ] |
405 | 407 | ) |
406 | 408 |
|
407 | 409 | # Transformation passes (post scalar removal) |
408 | 410 | self.add_passes( |
409 | 411 | [ |
410 | | - NormalizeWhileInitialArgsPass(use_exir_clone=False), |
411 | | - DecomposeAddSubAlphaPass(), |
412 | | - DecomposeGroupNormPass(), |
413 | | - DecomposeLayerNormPass(), |
414 | | - DecomposeVarPass(), |
415 | | - DecomposeMeanDimPass(graph_module, self.tosa_spec), |
416 | | - DecomposeNotEqualPass(), |
417 | | - DecomposeCosineSimilarityPass(), |
418 | | - DecomposeGluPass(), |
419 | | - DecomposeDivPass(), |
420 | | - DecomposeLeakyReLUPass(), |
421 | | - DecomposeLinalgVectorNormPass(), |
422 | | - DecomposeSqrtPass(), |
423 | | - DecomposeSiluPass(), |
424 | | - DecomposeAvgPool2dPass(), |
425 | | - DecomposeSoftmaxUnstablePass(), |
426 | | - DecomposeSoftmaxPass(), |
427 | | - ConvertMinMaxPass(), |
| 412 | + NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True), |
| 413 | + DecomposeAddSubAlphaPass(tfa_pass=True), |
| 414 | + DecomposeGroupNormPass(tfa_pass=True), |
| 415 | + DecomposeLayerNormPass(tfa_pass=True), |
| 416 | + DecomposeVarPass(tfa_pass=True), |
| 417 | + DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True), |
| 418 | + DecomposeNotEqualPass(tfa_pass=True), |
| 419 | + DecomposeCosineSimilarityPass(tfa_pass=True), |
| 420 | + DecomposeGluPass(tfa_pass=True), |
| 421 | + DecomposeDivPass(tfa_pass=True), |
| 422 | + DecomposeLeakyReLUPass(tfa_pass=True), |
| 423 | + DecomposeLinalgVectorNormPass(tfa_pass=True), |
| 424 | + DecomposeSqrtPass(tfa_pass=True), |
| 425 | + DecomposeSiluPass(tfa_pass=True), |
| 426 | + DecomposeAvgPool2dPass(tfa_pass=True), |
| 427 | + DecomposeSoftmaxUnstablePass(tfa_pass=True), |
| 428 | + DecomposeSoftmaxPass(tfa_pass=True), |
| 429 | + ConvertMinMaxPass(tfa_pass=True), |
428 | 430 | ] |
429 | 431 | ) |
430 | 432 |
|
431 | 433 | # Postprocessing passes |
432 | 434 | self.add_passes( |
433 | 435 | [ |
434 | | - ReplaceInfAndLimitValuesPass(), |
435 | | - DecomposeMaskedFillPass(), |
| 436 | + ReplaceInfAndLimitValuesPass(tfa_pass=True), |
| 437 | + DecomposeMaskedFillPass(tfa_pass=True), |
436 | 438 | ] |
437 | 439 | ) |
438 | 440 |
|
|
0 commit comments