Skip to content

Commit 05df33b

Browse files
authored
Add missing documentation for fused kernels (#744)
1 parent 2f9595a commit 05df33b

File tree

3 files changed

+15
-0
lines changed

3 files changed

+15
-0
lines changed

operators/cuda/negxplus1.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
namespace contrib {
1010

11+
/**
12+
* NegXPlus1(X) = 1 - X
13+
*/
1114
template <typename T>
1215
struct NegXPlus1 {
1316
template <typename TDict>

operators/cuda/scatter_nd_of_shape.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
namespace contrib {
1010

11+
/**
12+
* ScatterNDOfShape(shape, indices, updates) = ScatterND(ConstantOfShape(shape, value=0), indices, updates)
13+
*/
1114
template <typename T>
1215
struct ScatterNDOfShape {
1316
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
@@ -71,6 +74,12 @@ struct ScatterNDOfShape {
7174
};
7275

7376

77+
/**
78+
* MaskedScatterNDOfShape(shape, indices, updates) = ScatterND(ConstantOfShape(shape, value=0),
79+
* indices[indices != maskedValue],
80+
* updates[indices != maskedValue])
81+
*
82+
*/
7483
template <typename T>
7584
struct MaskedScatterNDOfShape {
7685
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {

operators/cuda/transpose_cast.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
namespace contrib {
1010

11+
/**
12+
* Transpose2DCast(X, to=to) = Cast(Transpose(X, perm=[1, 0]), to=to)
13+
*/
1114
template <typename TIN, typename TOUT>
1215
struct Transpose2DCast {
1316
template <typename TDict>

0 commit comments

Comments
 (0)