Skip to content

Commit a1dc797

Browse files
committed
Add SoftmaxCrossEntropyLoss(Grad)
1 parent 4540231 commit a1dc797

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

Deeploy/OperatorDescriptor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,18 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
694694
attrDescriptors = [AttrDesc("lr", FloatUnpack)],
695695
)
696696

697+
softmaxCrossEntropyLossDesc = OperatorDescriptor(
698+
inputDescriptor = IoDesc(["logits", "labels"]),
699+
outputDescriptor = IoDesc("log_prob"),
700+
attrDescriptors = [],
701+
)
702+
703+
softmaxCrossEntropyLossGradDesc = OperatorDescriptor(
704+
inputDescriptor = IoDesc(["log_prob", "labels"]),
705+
outputDescriptor = IoDesc("grad"),
706+
attrDescriptors = [],
707+
)
708+
697709
defaultOperatorDescriptors: Dict[str, OperatorDescriptor] = {
698710
"Add": addDesc,
699711
"CLCA": clcaDesc,
@@ -733,6 +745,8 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
733745
"SGD": sgdDesc,
734746
"Slice": sliceDesc,
735747
"Softmax": softmaxDesc,
748+
"SoftmaxCrossEntropyLoss": softmaxCrossEntropyLossDesc,
749+
"SoftmaxCrossEntropyLossGrad": softmaxCrossEntropyLossGradDesc,
736750
"SoftmaxGrad": softmaxGradDesc,
737751
"Squeeze": squeezeDesc,
738752
"Transpose": transposeDesc,

0 commit comments

Comments
 (0)