|
103 | 103 | RemoveNoopPass, |
104 | 104 | ReplaceInfAndLimitValuesPass, |
105 | 105 | ReplaceScalarWithTensorByProfilePass, |
| 106 | + RewriteBoolBitwiseNotToLogicalNotPass, |
106 | 107 | RewriteBoolToFp32CastViaInt8Pass, |
107 | 108 | RewriteConvPass, |
108 | 109 | RewriteMatmulPass, |
@@ -222,6 +223,7 @@ def _tosa_pipeline( |
222 | 223 | self.add_passes( |
223 | 224 | [ |
224 | 225 | FuseQuantizedActivationPass(), |
| 226 | + RewriteBoolBitwiseNotToLogicalNotPass(), |
225 | 227 | RewriteBoolToFp32CastViaInt8Pass(), |
226 | 228 | ConvertToClampPass(), |
227 | 229 | DecomposeTOSAUnsupportedClampPass(), |
@@ -376,65 +378,65 @@ def transform_to_backend_pipeline( |
376 | 378 |
|
377 | 379 | def transform_for_annotation_pipeline(self, graph_module: GraphModule): |
378 | 380 | # Preprocessing passes |
379 | | - self.add_pass(RemoveGraphAssertsPass()) |
| 381 | + self.add_pass(RemoveGraphAssertsPass(tfa_pass=True)) |
380 | 382 |
|
381 | 383 | # Transformation passes (pre scalar -> tensor) |
382 | 384 | self.add_passes( |
383 | 385 | [ |
384 | | - DecomposeSelectScatterPass(), |
385 | | - ConvertInt64ConstOpsToInt32Pass(), |
386 | | - ConvertInt64OutputOpsToInt32Pass(), |
387 | | - InsertInt32CastsAfterInt64PlaceholdersPass(), |
388 | | - DecomposeEmbeddingPass(), |
389 | | - DecomposeScaledDotProductAttentionPass(), |
390 | | - DecomposeRoundPass(), |
391 | | - DecomposeLogitPass(), |
392 | | - PromoteBoolOperandsPass(), |
393 | | - DecomposeSignPass(), |
394 | | - DecomposeAddmmPass(), |
395 | | - DecomposeRemainderPass(), |
396 | | - DecomposeFloorDividePass(), |
397 | | - DecomposeDivTensorModePass(), |
| 386 | + DecomposeSelectScatterPass(tfa_pass=True), |
| 387 | + ConvertInt64ConstOpsToInt32Pass(tfa_pass=True), |
| 388 | + ConvertInt64OutputOpsToInt32Pass(tfa_pass=True), |
| 389 | + InsertInt32CastsAfterInt64PlaceholdersPass(tfa_pass=True), |
| 390 | + DecomposeEmbeddingPass(tfa_pass=True), |
| 391 | + DecomposeScaledDotProductAttentionPass(tfa_pass=True), |
| 392 | + DecomposeRoundPass(tfa_pass=True), |
| 393 | + DecomposeLogitPass(tfa_pass=True), |
| 394 | + PromoteBoolOperandsPass(tfa_pass=True), |
| 395 | + DecomposeSignPass(tfa_pass=True), |
| 396 | + DecomposeAddmmPass(tfa_pass=True), |
| 397 | + DecomposeRemainderPass(tfa_pass=True), |
| 398 | + DecomposeFloorDividePass(tfa_pass=True), |
| 399 | + DecomposeDivTensorModePass(tfa_pass=True), |
398 | 400 | ] |
399 | 401 | ) |
400 | 402 |
|
401 | 403 | # Scalars -> tensors |
402 | 404 | self.add_passes( |
403 | 405 | [ |
404 | | - ReplaceScalarWithTensorByProfilePass(), |
405 | | - ScalarsToAttributePass(), |
| 406 | + ReplaceScalarWithTensorByProfilePass(tfa_pass=True), |
| 407 | + ScalarsToAttributePass(tfa_pass=True), |
406 | 408 | ] |
407 | 409 | ) |
408 | 410 |
|
409 | 411 | # Transformation passes (post scalar removal) |
410 | 412 | self.add_passes( |
411 | 413 | [ |
412 | | - NormalizeWhileInitialArgsPass(use_exir_clone=False), |
413 | | - DecomposeAddSubAlphaPass(), |
414 | | - DecomposeGroupNormPass(), |
415 | | - DecomposeLayerNormPass(), |
416 | | - DecomposeVarPass(), |
417 | | - DecomposeMeanDimPass(graph_module, self.tosa_spec), |
418 | | - DecomposeNotEqualPass(), |
419 | | - DecomposeCosineSimilarityPass(), |
420 | | - DecomposeGluPass(), |
421 | | - DecomposeDivPass(), |
422 | | - DecomposeLeakyReLUPass(), |
423 | | - DecomposeLinalgVectorNormPass(), |
424 | | - DecomposeSqrtPass(), |
425 | | - DecomposeSiluPass(), |
426 | | - DecomposeAvgPool2dPass(), |
427 | | - DecomposeSoftmaxUnstablePass(), |
428 | | - DecomposeSoftmaxPass(), |
429 | | - ConvertMinMaxPass(), |
| 414 | + NormalizeWhileInitialArgsPass(use_exir_clone=False, tfa_pass=True), |
| 415 | + DecomposeAddSubAlphaPass(tfa_pass=True), |
| 416 | + DecomposeGroupNormPass(tfa_pass=True), |
| 417 | + DecomposeLayerNormPass(tfa_pass=True), |
| 418 | + DecomposeVarPass(tfa_pass=True), |
| 419 | + DecomposeMeanDimPass(graph_module, self.tosa_spec, tfa_pass=True), |
| 420 | + DecomposeNotEqualPass(tfa_pass=True), |
| 421 | + DecomposeCosineSimilarityPass(tfa_pass=True), |
| 422 | + DecomposeGluPass(tfa_pass=True), |
| 423 | + DecomposeDivPass(tfa_pass=True), |
| 424 | + DecomposeLeakyReLUPass(tfa_pass=True), |
| 425 | + DecomposeLinalgVectorNormPass(tfa_pass=True), |
| 426 | + DecomposeSqrtPass(tfa_pass=True), |
| 427 | + DecomposeSiluPass(tfa_pass=True), |
| 428 | + DecomposeAvgPool2dPass(tfa_pass=True), |
| 429 | + DecomposeSoftmaxUnstablePass(tfa_pass=True), |
| 430 | + DecomposeSoftmaxPass(tfa_pass=True), |
| 431 | + ConvertMinMaxPass(tfa_pass=True), |
430 | 432 | ] |
431 | 433 | ) |
432 | 434 |
|
433 | 435 | # Postprocessing passes |
434 | 436 | self.add_passes( |
435 | 437 | [ |
436 | | - ReplaceInfAndLimitValuesPass(), |
437 | | - DecomposeMaskedFillPass(), |
| 438 | + ReplaceInfAndLimitValuesPass(tfa_pass=True), |
| 439 | + DecomposeMaskedFillPass(tfa_pass=True), |
438 | 440 | ] |
439 | 441 | ) |
440 | 442 |
|
|
0 commit comments