Skip to content

Commit 0ada211

Browse files
Merge pull request #35 from NumPower/feat/update_package_1
Broadcasting fixes and updates
2 parents 89528a6 + 544adeb commit 0ada211

File tree

6 files changed

+74
-58
lines changed

6 files changed

+74
-58
lines changed

numpower.c

Lines changed: 16 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,7 @@ PHP_METHOD(NDArray, normal) {
809809
if (nda == NULL) return;
810810
shape = emalloc(sizeof(int) * NDArray_NUMELEMENTS(nda));
811811
for (int i = 0; i < NDArray_NUMELEMENTS(nda); i++) {
812-
shape[i] = (int) NDArray_DDATA(nda)[i];
812+
shape[i] = (int) NDArray_FDATA(nda)[i];
813813
}
814814
rtn = NDArray_Normal(loc, scale, shape, NDArray_NUMELEMENTS(nda));
815815
NDArray_FREE(nda);
@@ -1260,36 +1260,23 @@ PHP_METHOD(NDArray, atleast_1d) {
12601260
* @param return_value
12611261
*/
12621262
ZEND_BEGIN_ARG_INFO_EX(arginfo_ndarray_atleast_2d, 0, 0, 1)
1263-
ZEND_ARG_INFO(0, array)
1264-
ZEND_ARG_INFO(0, axis)
1263+
ZEND_ARG_INFO(0, a)
12651264
ZEND_END_ARG_INFO()
12661265
PHP_METHOD(NDArray, atleast_2d) {
12671266
NDArray *rtn = NULL;
1268-
zval *array;
1267+
zval *a;
12691268
long axis;
12701269
int axis_i;
12711270
ZEND_PARSE_PARAMETERS_START(1, 1)
1272-
Z_PARAM_ZVAL(array)
1273-
Z_PARAM_OPTIONAL
1274-
Z_PARAM_LONG(axis)
1271+
Z_PARAM_ZVAL(a)
12751272
ZEND_PARSE_PARAMETERS_END();
1276-
NDArray *nda = ZVAL_TO_NDARRAY(array);
1273+
NDArray *nda = ZVAL_TO_NDARRAY(a);
12771274
if (nda == NULL) {
12781275
return;
12791276
}
1280-
axis_i = (int)axis;
1281-
if (ZEND_NUM_ARGS() == 1) {
1282-
rtn = NDArray_Transpose(nda, NULL);
1283-
add_to_buffer(rtn, sizeof(NDArray));
1284-
RETURN_NDARRAY(rtn, return_value);
1285-
} else {
1286-
if (NDArray_DEVICE(nda) == NDARRAY_DEVICE_GPU) {
1287-
zend_throw_error(NULL, "Axis not supported for GPU operation");
1288-
return;
1289-
}
1290-
zend_throw_error(NULL, "Not implemented");
1291-
return;
1292-
}
1277+
//rtn = NDArray_AtLeast2D(nda);
1278+
CHECK_INPUT_AND_FREE(a, nda);
1279+
RETURN_NDARRAY(rtn, return_value);
12931280
}
12941281

12951282
/**
@@ -2843,9 +2830,8 @@ PHP_METHOD(NDArray, subtract) {
28432830
CHECK_INPUT_AND_FREE(a, nda);
28442831
return;
28452832
}
2846-
if (!NDArray_ShapeCompare(nda, ndb)) {
2847-
zend_throw_error(NULL, "Incompatible shapes");
2848-
return;
2833+
if (!NDArray_IsBroadcastable(nda, ndb)) {
2834+
zend_throw_error(NULL, "Can´t broadcast array.");
28492835
}
28502836
rtn = NDArray_Subtract_Float(nda, ndb);
28512837
CHECK_INPUT_AND_FREE(a, nda);
@@ -2877,9 +2863,8 @@ PHP_METHOD(NDArray, mod) {
28772863
CHECK_INPUT_AND_FREE(a, nda);
28782864
return;
28792865
}
2880-
if (!NDArray_ShapeCompare(nda, ndb)) {
2881-
zend_throw_error(NULL, "Incompatible shapes");
2882-
return;
2866+
if (!NDArray_IsBroadcastable(nda, ndb)) {
2867+
zend_throw_error(NULL, "Can´t broadcast array.");
28832868
}
28842869
rtn = NDArray_Mod_Float(nda, ndb);
28852870
CHECK_INPUT_AND_FREE(a, nda);
@@ -2945,9 +2930,8 @@ PHP_METHOD(NDArray, multiply) {
29452930
CHECK_INPUT_AND_FREE(a, nda);
29462931
return;
29472932
}
2948-
if (!NDArray_ShapeCompare(nda, ndb)) {
2949-
zend_throw_error(NULL, "Incompatible shapes");
2950-
return;
2933+
if (!NDArray_IsBroadcastable(nda, ndb)) {
2934+
zend_throw_error(NULL, "Can´t broadcast array.");
29512935
}
29522936
rtn = NDArray_Multiply_Float(nda, ndb);
29532937

@@ -2979,9 +2963,8 @@ PHP_METHOD(NDArray, divide) {
29792963
if (ndb == NULL) {
29802964
return;
29812965
}
2982-
if (!NDArray_ShapeCompare(nda, ndb)) {
2983-
zend_throw_error(NULL, "Incompatible shapes");
2984-
return;
2966+
if (!NDArray_IsBroadcastable(nda, ndb)) {
2967+
zend_throw_error(NULL, "Can´t broadcast array.");
29852968
}
29862969
rtn = NDArray_Divide_Float(nda, ndb);
29872970
CHECK_INPUT_AND_FREE(a, nda);

src/debug.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@ NDArray_Dump(NDArray* array) {
3131
}
3232
printf(" ]\n");
3333
if (NDArray_DEVICE(array) == NDARRAY_DEVICE_GPU) {
34-
printf("NDArray.device\t\t\t%s\n", "GPU");
34+
printf("NDArray.device\t\t\t(%d) %s\n", NDArray_DEVICE(array), "GPU");
35+
} else if(NDArray_DEVICE(array) == NDARRAY_DEVICE_CPU) {
36+
printf("NDArray.device\t\t\t(%d) %s\n", NDArray_DEVICE(array), "CPU");
3537
} else {
36-
printf("NDArray.device\t\t\t%s\n", "CPU");
38+
printf("NDArray.device\t\t\t(%d) %s\n", NDArray_DEVICE(array), "ERROR");
3739
}
3840
printf("NDArray.refcount\t\t%d\n", array->refcount);
3941
printf("NDArray.descriptor.elsize\t%d\n", array->descriptor->elsize);

src/manipulation.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,4 +326,4 @@ NDArray_Append(NDArray *a, NDArray *b) {
326326
}
327327

328328
return rtn;
329-
}
329+
}

src/ndarray.c

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,13 @@ NDArray_Broadcast(NDArray *a, NDArray *b) {
10761076
}
10771077
rtn = NDArray_Copy(dst, NDArray_DEVICE(dst));
10781078
char *rtn_p = NDArray_DATA(rtn);
1079+
1080+
if (NDArray_NDIM(a) == 0 && NDArray_NDIM(b) > 0) {
1081+
for (i = 0; i < NDArray_NUMELEMENTS(b); i++) {
1082+
NDArray_FDATA(rtn)[i] = NDArray_FDATA(a)[0];
1083+
}
1084+
}
1085+
10791086
if (NDArray_NDIM(src) == 1 && NDArray_NDIM(dst) > 1) {
10801087
if (NDArray_SHAPE(src)[0] == NDArray_SHAPE(dst)[NDArray_NDIM(dst) - 2]) {
10811088
if (NDArray_DEVICE(dst) == NDARRAY_DEVICE_CPU) {
@@ -1100,20 +1107,42 @@ NDArray_Broadcast(NDArray *a, NDArray *b) {
11001107
if (NDArray_NDIM(src) == 2 && NDArray_NDIM(dst) == 2) {
11011108
if (NDArray_SHAPE(src)[NDArray_NDIM(dst) - 2] == NDArray_SHAPE(dst)[NDArray_NDIM(dst) - 2]) {
11021109
if (NDArray_DEVICE(dst) == NDARRAY_DEVICE_CPU) {
1103-
for (i = 0; i < NDArray_SHAPE(dst)[NDArray_NDIM(dst) - 2]; i++) {
1104-
for (j = 0; j < NDArray_SHAPE(dst)[NDArray_NDIM(dst) - 1]; j++) {
1105-
NDArray_FDATA(rtn)[(i * NDArray_STRIDES(rtn)[NDArray_NDIM(rtn) - 2]/ NDArray_ELSIZE(rtn))+j] = NDArray_FDATA(src)[i];
1110+
if (NDArray_NUMELEMENTS(src) != 1) {
1111+
for (i = 0; i < NDArray_SHAPE(dst)[NDArray_NDIM(dst) - 2]; i++) {
1112+
for (j = 0; j < NDArray_SHAPE(dst)[NDArray_NDIM(dst) - 1]; j++) {
1113+
NDArray_FDATA(rtn)[(i * NDArray_STRIDES(rtn)[NDArray_NDIM(rtn) - 2] / NDArray_ELSIZE(rtn)) +
1114+
j] = NDArray_FDATA(src)[i];
1115+
}
1116+
}
1117+
} else {
1118+
for (i = 0; i < NDArray_SHAPE(dst)[NDArray_NDIM(dst) - 2]; i++) {
1119+
for (j = 0; j < NDArray_SHAPE(dst)[NDArray_NDIM(dst) - 1]; j++) {
1120+
NDArray_FDATA(rtn)[(i * NDArray_STRIDES(rtn)[NDArray_NDIM(rtn) - 2] / NDArray_ELSIZE(rtn)) +
1121+
j] = NDArray_FDATA(src)[0];
1122+
}
11061123
}
1124+
return rtn;
11071125
}
11081126
}
11091127
#ifdef HAVE_CUBLAS
11101128
if (NDArray_DEVICE(dst) == NDARRAY_DEVICE_GPU) {
1111-
for (i = 0; i < NDArray_SHAPE(dst)[NDArray_NDIM(dst) - 2]; i++) {
1112-
for (j = 0; j < NDArray_SHAPE(dst)[NDArray_NDIM(dst) - 1]; j++) {
1113-
tmp_p = (char*)(NDArray_FDATA(src) + i);
1114-
rtn_p = (char*)(NDArray_FDATA(rtn) + (i * NDArray_STRIDES(rtn)[NDArray_NDIM(rtn) - 2]/ NDArray_ELSIZE(rtn))+j);
1115-
NDArray_VMEMCPY_D2D(tmp_p, rtn_p, sizeof(float));
1129+
if (NDArray_NUMELEMENTS(src) != 1) {
1130+
for (i = 0; i < NDArray_SHAPE(dst)[NDArray_NDIM(dst) - 2]; i++) {
1131+
for (j = 0; j < NDArray_SHAPE(dst)[NDArray_NDIM(dst) - 1]; j++) {
1132+
tmp_p = (char*)(NDArray_FDATA(src) + i);
1133+
rtn_p = (char*)(NDArray_FDATA(rtn) + (i * NDArray_STRIDES(rtn)[NDArray_NDIM(rtn) - 2]/ NDArray_ELSIZE(rtn))+j);
1134+
NDArray_VMEMCPY_D2D(tmp_p, rtn_p, sizeof(float));
1135+
}
1136+
}
1137+
} else {
1138+
for (i = 0; i < NDArray_SHAPE(dst)[NDArray_NDIM(dst) - 2]; i++) {
1139+
for (j = 0; j < NDArray_SHAPE(dst)[NDArray_NDIM(dst) - 1]; j++) {
1140+
tmp_p = (char*)(NDArray_FDATA(src));
1141+
rtn_p = (char*)(NDArray_FDATA(rtn) + (i * NDArray_STRIDES(rtn)[NDArray_NDIM(rtn) - 2]/ NDArray_ELSIZE(rtn))+j);
1142+
NDArray_VMEMCPY_D2D(tmp_p, rtn_p, sizeof(float));
1143+
}
11161144
}
1145+
return rtn;
11171146
}
11181147
}
11191148
#endif

src/ndmath/arithmetics.c

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -185,15 +185,12 @@ NDArray_Median_Float(NDArray* a) {
185185

186186
NDArray*
187187
NDArray_Add_Float(NDArray* a, NDArray* b) {
188-
NDArray *broadcasted = NULL;
189188
if (NDArray_DEVICE(a) != NDArray_DEVICE(b)) {
190189
zend_throw_error(NULL, "Device mismatch, both NDArray MUST be in the same device.");
191190
return NULL;
192191
}
193192

194-
NDArray *a_broad = NULL, *b_broad = NULL;
195-
196-
if (NDArray_NDIM(a) == 0) {
193+
if (NDArray_NDIM(a) == 0 && NDArray_NDIM(b) == 0) {
197194
int* shape = ecalloc(1, sizeof(int));
198195
NDArray *rtn = NDArray_Zeros(shape, 0, NDARRAY_TYPE_FLOAT32, NDArray_DEVICE(a));
199196
#ifdef HAVE_CUBLAS
@@ -208,6 +205,9 @@ NDArray_Add_Float(NDArray* a, NDArray* b) {
208205
return rtn;
209206
}
210207

208+
NDArray *broadcasted = NULL;
209+
NDArray *a_broad = NULL, *b_broad = NULL;
210+
211211
if (NDArray_NUMELEMENTS(a) < NDArray_NUMELEMENTS(b)) {
212212
broadcasted = NDArray_Broadcast(a, b);
213213
a_broad = broadcasted;
@@ -250,6 +250,7 @@ NDArray_Add_Float(NDArray* a, NDArray* b) {
250250
result->descriptor = (NDArrayDescriptor*)emalloc(sizeof(NDArrayDescriptor));
251251
result->descriptor->type = NDARRAY_TYPE_FLOAT32;
252252
result->descriptor->elsize = sizeof(float);
253+
result->device = NDArray_DEVICE(a_broad);
253254
result->descriptor->numElements = a_broad->descriptor->numElements;
254255
result->refcount = 1;
255256

@@ -278,7 +279,7 @@ NDArray_Add_Float(NDArray* a, NDArray* b) {
278279
_mm256_storeu_ps(&resultData[i], mul);
279280
}
280281
// Handle remaining elements if the length is not a multiple of 8
281-
for (; i < NDArray_NUMELEMENTS(a); i++) {
282+
for (; i < numElements; i++) {
282283
resultData[i] = aData[i] + bData[i];
283284
}
284285
#elif HAVE_CBLAS
@@ -365,7 +366,6 @@ NDArray_Multiply_Float(NDArray* a, NDArray* b) {
365366
b_broad = b;
366367
a_broad = a;
367368
}
368-
369369
if (b_broad == NULL || a_broad == NULL) {
370370
zend_throw_error(NULL, "Can't broadcast arrays.");
371371
return NULL;
@@ -415,18 +415,18 @@ NDArray_Multiply_Float(NDArray* a, NDArray* b) {
415415
#endif
416416
} else {
417417
#ifdef HAVE_AVX2
418-
int i;
418+
int i = 0;
419419
__m256 vec1, vec2, mul;
420420

421-
for (i = 0; i < NDArray_NUMELEMENTS(a) - 7; i += 8) {
421+
for (; i < NDArray_NUMELEMENTS(a) - 7; i += 8) {
422422
vec1 = _mm256_loadu_ps(&aData[i]);
423423
vec2 = _mm256_loadu_ps(&bData[i]);
424424
mul = _mm256_mul_ps(vec1, vec2);
425425
_mm256_storeu_ps(&resultData[i], mul);
426426
}
427427

428428
// Handle remaining elements if the length is not a multiple of 8
429-
for (; i < NDArray_NUMELEMENTS(a); i++) {
429+
for (; i < numElements; i++) {
430430
resultData[i] = aData[i] * bData[i];
431431
}
432432
#else
@@ -553,7 +553,7 @@ NDArray_Subtract_Float(NDArray* a, NDArray* b) {
553553
}
554554

555555
// Handle remaining elements if the length is not a multiple of 4
556-
for (; i < NDArray_NUMELEMENTS(a); i++) {
556+
for (; i < numElements; i++) {
557557
resultData[i] = aData[i] - bData[i];
558558
}
559559
#else
@@ -584,6 +584,7 @@ NDArray_Subtract_Float(NDArray* a, NDArray* b) {
584584
NDArray*
585585
NDArray_Divide_Float(NDArray* a, NDArray* b) {
586586
NDArray *a_temp = NULL, *b_temp = NULL;
587+
587588
if (NDArray_DEVICE(a) != NDArray_DEVICE(b)) {
588589
zend_throw_error(NULL, "Device mismatch, both NDArray MUST be in the same device.");
589590
return NULL;
@@ -658,6 +659,7 @@ NDArray_Divide_Float(NDArray* a, NDArray* b) {
658659
result->descriptor = (NDArrayDescriptor *) emalloc(sizeof(NDArrayDescriptor));
659660
result->descriptor->type = NDARRAY_TYPE_FLOAT32;
660661
result->descriptor->elsize = sizeof(float);
662+
result->device = NDArray_DEVICE(a_broad);
661663
result->descriptor->numElements = a_broad->descriptor->numElements;
662664
result->refcount = 1;
663665

@@ -687,7 +689,7 @@ NDArray_Divide_Float(NDArray* a, NDArray* b) {
687689
}
688690

689691
// Handle remaining elements if the length is not a multiple of 4
690-
for (; i < NDArray_NUMELEMENTS(a); i++) {
692+
for (; i < numElements; i++) {
691693
resultData[i] = aData[i] / bData[i];
692694
}
693695
#else

tests/003-ndarray-add.phpt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ Array
1313
(
1414
[0] => Array
1515
(
16-
[0] => 2
16+
[0] => 3
1717
[1] => 4
1818
)
1919

2020
[1] => Array
2121
(
22-
[0] => 6
23-
[1] => 8
22+
[0] => 5
23+
[1] => 6
2424
)
2525

2626
)

0 commit comments

Comments
 (0)