Merged
1 change: 0 additions & 1 deletion backends/xnnpack/operators/__init__.py
@@ -39,7 +39,6 @@
     op_quant_dequant,
     op_relu,
     op_rsqrt,
-    op_sdpa,
     op_sigmoid,
     op_skip_ops,
     op_slice_copy,
111 changes: 0 additions & 111 deletions backends/xnnpack/operators/op_sdpa.py

This file was deleted.

2 changes: 0 additions & 2 deletions backends/xnnpack/partition/config/__init__.py
@@ -43,7 +43,6 @@
     QuantizedPerTensorConfig,
     ReciprocalSquareRootConfig,
     ReLUConfig,
-    # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,
     SoftmaxConfig,
@@ -99,7 +98,6 @@
     PreluConfig,
     ReciprocalSquareRootConfig,
     ReLUConfig,
-    # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,
     SoftmaxConfig,
30 changes: 0 additions & 30 deletions backends/xnnpack/partition/config/generic_node_configs.py
@@ -527,33 +527,3 @@ class BMMConfig(GenericNodePartitionerConfig):

     def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
-
-
-class SDPAConfig(GenericNodePartitionerConfig):
-    target_name = "scaled_dot_product_attention.default"
-
-    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
-        """
-        Requires Mask to have Rank 2
-        """
-        if not self.check_common_constraints(node, ep):
-            return False
-
-        if len(node.all_input_nodes) < 4:
-            return False
-        mask_node = node.all_input_nodes[3]
-        mask_rank = mask_node.meta["val"].dim()
-        if mask_rank != 2:
-            why(
-                node,
-                reason=f"mask must have rank 2, got mask of rank {mask_rank}",
-            )
-            return False
-
-        return True
-
-    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
-        return torch.ops.aten.scaled_dot_product_attention.default
-
-    def supported_precision_types(self) -> List[ConfigPrecisionType]:
-        return [ConfigPrecisionType.FP32]
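
Note for reviewers: the deleted check_constraints above only partitioned SDPA calls whose attention mask had rank 2. A minimal sketch of what that accepted (shapes and values are illustrative assumptions, not taken from this PR):

import torch
import torch.nn.functional as F

# Hypothetical shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 4, 16, 64)
k = torch.randn(1, 4, 16, 64)
v = torch.randn(1, 4, 16, 64)

# Rank-2 (seq_len, seq_len) additive mask: mask.dim() == 2, so the deleted
# config would have partitioned this call to XNNPACK.
mask = torch.zeros(16, 16)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# A broadcast rank-4 mask, e.g. torch.zeros(1, 1, 16, 16), has dim() == 4;
# check_constraints would log "mask must have rank 2, got mask of rank 4"
# and decline the node.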
37 changes: 0 additions & 37 deletions backends/xnnpack/runtime/XNNCompiler.cpp
@@ -1961,42 +1961,6 @@ Error defineStaticSliceNode(
   return Error::Ok;
 }
 
-/*
-Defines Scaled Dot Product Attention (SDPA) node into the subgraph,
-using the remapped ids to map the serialized ids,
-to the new ids generated when defining the tensor value
-*/
-Error defineScaledDotProductAttentionNode(
-    xnn_subgraph_t subgraph_ptr,
-    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
-    const NodePtr node,
-    const fb_xnnpack::XNNGraph* graph) noexcept {
-  MAYBE_UNUSED(graph);
-
-  auto graph_node = node->xnode_union_as_XNNScaledDotProductAttention();
-
-  xnn_status status = xnn_define_scaled_dot_product_attention(
-      subgraph_ptr,
-      xnn_attention_logits_cap_type_none, // cap_type
-      nullptr, // cap_value - not used
-      remapped_ids.at(graph_node->query_id()),
-      remapped_ids.at(graph_node->key_id()),
-      remapped_ids.at(graph_node->value_id()),
-      remapped_ids.at(graph_node->scale_id()),
-      remapped_ids.at(graph_node->mask_id()),
-      remapped_ids.at(graph_node->output_id()),
-      graph_node->flags());
-
-  ET_CHECK_OR_RETURN_ERROR(
-      status == xnn_status_success,
-      Internal,
-      "Failed to create SDPA node %i with code: %s",
-      node->debug_handle(),
-      xnn_status_to_string(status));
-
-  return Error::Ok;
-}
-
 /*
 Defines batch matrix multiply node into the subgraph,
 using the remapped ids to map the serialized ids,
@@ -2097,7 +2061,6 @@ DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
     _DEFINE(Concatenate4)
     _DEFINE(Concatenate5)
     _DEFINE(StaticSlice)
-    _DEFINE(ScaledDotProductAttention)
     _DEFINE(BatchMatrixMultiply)
     case fb_xnnpack::XNodeUnion::NONE:
     default: // Adding here as a catch all, just in case
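
Note: like the other define*Node functions in XNNCompiler.cpp, the deleted SDPA definition resolves each flatbuffer-serialized tensor id through remapped_ids before handing it to XNNPACK. A rough Python sketch of that lookup (dict contents and field values are hypothetical stand-ins for the serialized graph):

# serialized id (flatbuffer graph) -> runtime id (assigned when the tensor
# value was defined in the live xnn_subgraph)
remapped_ids = {7: 0, 8: 1, 9: 2, 10: 3, 11: 4, 12: 5}

# Stand-in for the XNNScaledDotProductAttention node's fields.
sdpa_node = {"query_id": 7, "key_id": 8, "value_id": 9,
             "scale_id": 10, "mask_id": 11, "output_id": 12}

# Equivalent of the remapped_ids.at(...) calls in the deleted C++ above;
# these ids would be passed to xnn_define_scaled_dot_product_attention.
args = [remapped_ids[sdpa_node[f]]
        for f in ("query_id", "key_id", "value_id",
                  "scale_id", "mask_id", "output_id")]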
130 changes: 0 additions & 130 deletions backends/xnnpack/test/ops/test_sdpa.py

This file was deleted.
