Skip to content

Commit a6143c5

Browse files
committed
More refactoring.
Now, all the veryhighrank tests pass and the others fail for an unknown reason.
1 parent fa5222c commit a6143c5

File tree

3 files changed

+2379
-1609
lines changed

3 files changed

+2379
-1609
lines changed

src/gpuarray/reduction.h

Lines changed: 71 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -22,68 +22,98 @@ extern "C" {
2222

2323

2424
/* Data Structures */
25+
struct GpuReductionAttr;
2526
struct GpuReduction;
26-
typedef struct GpuReduction GpuReduction;
27+
typedef struct GpuReductionAttr GpuReductionAttr;
28+
typedef struct GpuReduction GpuReduction;
2729

2830

2931
/**
3032
* Supported array reduction operations.
3133
*/
3234

3335
typedef enum _ga_reduce_op {
34-
/* dst , dstArg */
35-
GA_REDUCE_SUM, /* + */
36-
GA_REDUCE_PROD, /* * */
37-
GA_REDUCE_PRODNZ, /* * (!=0) */
38-
GA_REDUCE_MIN, /* min() */
39-
GA_REDUCE_MAX, /* max() */
40-
GA_REDUCE_ARGMIN, /* argmin() */
41-
GA_REDUCE_ARGMAX, /* argmax() */
42-
GA_REDUCE_MINANDARGMIN, /* min() , argmin() */
43-
GA_REDUCE_MAXANDARGMAX, /* max() , argmax() */
44-
GA_REDUCE_AND, /* & */
45-
GA_REDUCE_OR, /* | */
46-
GA_REDUCE_XOR, /* ^ */
47-
GA_REDUCE_ALL, /* &&/all() */
48-
GA_REDUCE_ANY, /* ||/any() */
36+
/* d0 , d1 */
37+
GA_ELEMWISE,
38+
GA_REDUCE_COPY=GA_ELEMWISE, /* (copy) */
39+
GA_REDUCE_SUM, /* + */
40+
GA_REDUCE_PROD, /* * */
41+
GA_REDUCE_PRODNZ, /* * (!=0) */
42+
GA_REDUCE_MIN, /* min() */
43+
GA_REDUCE_MAX, /* max() */
44+
GA_REDUCE_ARGMIN, /* argmin() */
45+
GA_REDUCE_ARGMAX, /* argmax() */
46+
GA_REDUCE_MINANDARGMIN, /* min() , argmin() */
47+
GA_REDUCE_MAXANDARGMAX, /* max() , argmax() */
48+
GA_REDUCE_AND, /* & */
49+
GA_REDUCE_OR, /* | */
50+
GA_REDUCE_XOR, /* ^ */
51+
GA_REDUCE_ALL, /* &&/all() */
52+
GA_REDUCE_ANY, /* ||/any() */
4953

50-
GA_REDUCE_ENDSUPPORTED /* Must be last element in enum */
54+
GA_REDUCE_ENDSUPPORTED /* Must be last element in enum */
5155
} ga_reduce_op;
5256

5357

5458
/* External Functions */
5559

5660
/**
57-
* @brief Create a new GPU reduction operator over a list of axes to reduce.
61+
* @brief Create, modify and free the attributes of a reduction operator.
62+
*
63+
* @param [out] grAttr The reduction operator attributes object.
64+
* @param [in] op The reduction operation.
65+
* @param [in] maxSrcDims The maximum number of supported source dimensions.
66+
* @param [in] maxDstDims The maximum number of supported destination dimensions.
67+
* @param [in] s0Typecode The typecode of the source tensor.
68+
* @param [in] d0Typecode The typecode of the first destination tensor.
69+
* @param [in] d1Typecode The typecode of the second destination tensor.
70+
* @param [in] i0Typecode The typecode of the indices.
71+
*/
72+
73+
GPUARRAY_PUBLIC int GpuReductionAttr_new (GpuReductionAttr** grAttr,
74+
gpucontext* gpuCtx);
75+
GPUARRAY_PUBLIC int GpuReductionAttr_setop (GpuReductionAttr* grAttr,
76+
ga_reduce_op op);
77+
GPUARRAY_PUBLIC int GpuReductionAttr_setdims (GpuReductionAttr* grAttr,
78+
unsigned maxSrcDims,
79+
unsigned maxDstDims);
80+
GPUARRAY_PUBLIC int GpuReductionAttr_sets0type (GpuReductionAttr* grAttr,
81+
int s0Typecode);
82+
GPUARRAY_PUBLIC int GpuReductionAttr_setd0type (GpuReductionAttr* grAttr,
83+
int d0Typecode);
84+
GPUARRAY_PUBLIC int GpuReductionAttr_setd1type (GpuReductionAttr* grAttr,
85+
int d1Typecode);
86+
GPUARRAY_PUBLIC int GpuReductionAttr_seti0type (GpuReductionAttr* grAttr,
87+
int i0Typecode);
88+
GPUARRAY_PUBLIC int GpuReductionAttr_appendopname (GpuReductionAttr* grAttr,
89+
size_t n,
90+
char* name);
91+
GPUARRAY_PUBLIC int GpuReductionAttr_issensitive (const GpuReductionAttr* grAttr);
92+
GPUARRAY_PUBLIC int GpuReductionAttr_requiresS0 (const GpuReductionAttr* grAttr);
93+
GPUARRAY_PUBLIC int GpuReductionAttr_requiresD0 (const GpuReductionAttr* grAttr);
94+
GPUARRAY_PUBLIC int GpuReductionAttr_requiresD1 (const GpuReductionAttr* grAttr);
95+
GPUARRAY_PUBLIC void GpuReductionAttr_free (GpuReductionAttr* grAttr);
96+
97+
/**
98+
* @brief Create a new GPU reduction operator with the given attributes.
5899
*
59100
* @param [out] gr The reduction operator.
60-
* @param [in] gpuCtx The GPU context.
61-
* @param [in] op The reduction operation to perform.
62-
* @param [in] ndf The minimum number of free (destination) dimensions to support.
63-
* @param [in] ndr The minimum number of reduction (source) dimensions to support.
64-
* @param [in] s0TypeCode The data type of the source operand.
65-
* @param [in] flags Reduction operator creation flags. Currently must be
66-
* set to 0.
101+
* @param [in] grAttr The reduction operator attributes object.
67102
*
68103
* @return GA_NO_ERROR if the operator was created successfully
69-
* GA_INVALID_ERROR if grOut is NULL, or some other argument was invalid
104+
* GA_INVALID_ERROR if some argument was invalid
70105
* GA_NO_MEMORY if memory allocation failed anytime during creation
71106
* or other non-zero error codes otherwise.
72107
*/
73108

74-
GPUARRAY_PUBLIC int GpuReduction_new (GpuReduction** grOut,
75-
gpucontext* gpuCtx,
76-
ga_reduce_op op,
77-
unsigned ndf,
78-
unsigned ndr,
79-
int s0TypeCode,
80-
int flags);
109+
GPUARRAY_PUBLIC int GpuReduction_new (GpuReduction** gr,
110+
const GpuReductionAttr* grAttr);
81111

82112
/**
83113
* @brief Deallocate an operator allocated by GpuReduction_new().
84114
*/
85115

86-
GPUARRAY_PUBLIC void GpuReduction_free (GpuReduction* gr);
116+
GPUARRAY_PUBLIC void GpuReduction_free (GpuReduction* gr);
87117

88118
/**
89119
* @brief Invoke an operator allocated by GpuReduction_new() on a source tensor.
@@ -123,13 +153,13 @@ GPUARRAY_PUBLIC void GpuReduction_free (GpuReduction* gr);
123153
* error code otherwise.
124154
*/
125155

126-
GPUARRAY_PUBLIC int GpuReduction_call (const GpuReduction* gr,
127-
GpuArray* d0,
128-
GpuArray* d1,
129-
const GpuArray* s0,
130-
unsigned reduxLen,
131-
const int* reduxList,
132-
int flags);
156+
GPUARRAY_PUBLIC int GpuReduction_call (const GpuReduction* gr,
157+
GpuArray* d0,
158+
GpuArray* d1,
159+
const GpuArray* s0,
160+
unsigned reduxLen,
161+
const int* reduxList,
162+
int flags);
133163

134164

135165
#ifdef __cplusplus

0 commit comments

Comments
 (0)