Skip to content

Commit 7723845

Browse files
author
Kent Knox
committed
Merge pull request #41 from TimmyLiu/develop
a couple bug fixes related to c/z syr2k
2 parents d2cf285 + 42284c0 commit 7723845

File tree

4 files changed

+113
-7
lines changed

4 files changed

+113
-7
lines changed

src/library/blas/gens/blas_kgen.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include "tile.h"
5353
#include "fetch.h"
5454

55+
5556
#define BLAS_KGEN_FORMAT 1
5657

5758
#define genInternalLoopEnd(ctx) kgenEndBranch(ctx, NULL)
@@ -539,6 +540,18 @@ sprintfComplexMulUpdate(
539540
bool conjB,
540541
TileMulCore core);
541542

543+
/**
 * @brief Sprintf a complex multiply-update expression for the c/z syr2k
 *        result update used when beta is (0,0)
 *
 * NOTE(review): the parameter list has the same tail (conjB, core) as
 * sprintfComplexMulUpdate() declared just above; the individual parameter
 * meanings are presumably the same - confirm against the definition in
 * tilemul.c.
 */
void
544+
sprintfComplexMulUpdate_syr2k_beta0(
545+
Kstring *expr,
546+
const Kstring *dst,
547+
const Kstring *a,
548+
const Kstring *b,
549+
const Kstring *c,
550+
bool isDouble,
551+
bool conjA,
552+
bool conjB,
553+
TileMulCore core);
554+
542555
/**
543556
* @brief Sprintf expression of fast scalar mad
544557
*
@@ -892,4 +905,6 @@ checkGenRestoreTailCoords(
892905
UpdateResultFlags
893906
tailStatusToUpresFlags(TailStatus status);
894907

908+
909+
895910
#endif /* BLAS_KGEN_H_ */

src/library/blas/gens/syrxk.c

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -876,11 +876,19 @@ genUpdateSingleOptimized(
876876
sprintfComplexMulUpdate(&expr, k3, tempC, &betaStr, NULL,
877877
isDouble, false, false, core);
878878
kgenAddStmtToBatch(batch, MAD_STMT_PRIO, expr.buf);
879+
880+
sprintfComplexMulUpdate(&expr, tempC, result, &alphaStr, k3,
881+
isDouble, false, false, core);
882+
kgenAddStmtToBatch(batch, MAD_STMT_PRIO, expr.buf);
879883
}
884+
else
885+
{
886+
//fix correctness bug for c/z syr2k when beta = (0,0)
887+
sprintfComplexMulUpdate_syr2k_beta0(&expr, tempC, result, &alphaStr, NULL,
888+
isDouble, false, false, core);
889+
kgenAddStmtToBatch(batch, MAD_STMT_PRIO, expr.buf);
890+
}
880891

881-
sprintfComplexMulUpdate(&expr, tempC, result, &alphaStr, k3,
882-
isDouble, false, false, core);
883-
kgenAddStmtToBatch(batch, MAD_STMT_PRIO, expr.buf);
884892
}
885893
else {
886894
if (betaName != NULL) {
@@ -1171,7 +1179,6 @@ genUpdateIsoscelesDiagTile(
11711179
if (nrStored) {
11721180
sprintfTileElement(&tempElem, &tileTempC, iter.row % tempRows,
11731181
iter.col % tempCols, nrStored);
1174-
11751182
kgenBatchPrintf(batch, STORE_STMT_PRIO,
11761183
"*(__global %s*)(&%s[%s]) = %s;\n",
11771184
glbType, dstPtr, offExpr.buf, tempElem.buf);
@@ -1720,7 +1727,6 @@ genUpdateResult(
17201727
// the function above put a respective code into a conditional path
17211728
kgenBeginBranch(ctx, "else");
17221729
}
1723-
17241730
ret = genResultUpdateWithFlags( ctx,
17251731
funcID,
17261732
gset,

src/library/blas/gens/tilemul.c

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#include "blas_kgen.h"
3030

31+
3132
#define MAX_LENGTH 4096
3233
#define BITS_INT (sizeof(int) * 8)
3334

@@ -693,6 +694,90 @@ sprintfComplexMulUpdate(
693694
}
694695
}
695696

697+
/*
 * Format a complex multiply(-add) statement into 'expr'.
 *
 * expr     - output string; overwritten with the generated statement(s)
 * dst      - name of the destination the final result is assigned to
 * a, b     - names of the two complex operands
 * c        - optional addend; NULL means "0" in the mad() form and no
 *            trailing "+ c" term in the plain form
 * isDouble - selects "double2" instead of "float2" for the vector
 *            constructor in the generated text
 * conjA,
 * conjB    - select which operand is swapped / taken apart and which of
 *            the generated components are negated (see the table below)
 * core     - TILEMUL_MAD emits two mad() statements, any other value
 *            emits a single arithmetic expression
 *
 * NOTE(review): in the mad() form the intermediate result is written to a
 * variable literally named "sctmp" that this function does not declare;
 * presumably the generated kernel declares it elsewhere - confirm.
 */
void
698+
sprintfComplexMulUpdate_syr2k_beta0(
699+
Kstring *expr,
700+
const Kstring *dst,
701+
const Kstring *a,
702+
const Kstring *b,
703+
const Kstring *c,
704+
bool isDouble,
705+
bool conjA,
706+
bool conjB,
707+
TileMulCore core)
708+
{
709+
Kstring swSrc1; // swapped element of the first source
710+
// real and imaginary part of the second source
711+
Kstring reSrc2, imSrc2;
712+
const Kstring *src11, *src12, *src21, *src22;
713+
const char *sign1 = "", *sign2 = "", *sign3 = "";
714+
const char *baseType;
715+
716+
baseType = (isDouble) ? "double2" : "float2";
717+
718+
/*
719+
* Prepare components for multiplying. We should get the following
720+
* vectorized operations:
721+
*
722+
* c = b * a1 + bsw * (-a2, a2) if neither conjA nor conjB is set
723+
* c = b * a1 + bsw * (a2, -a2) if only conjA is set
724+
* c = a * b1 + asw * (b2, -b2) if only conjB is set
725+
* c = asw * (-b2) + a * (b1, -b1) if both conjA and conjB are set
726+
*
727+
* Where (a1, a2) and (b1, b2) are complex components of 'a' and 'b',
728+
* and asw and bsw - swapped elements of 'a' and 'b' respectively.
729+
*/
730+
731+
// conjB decides which operand is swapped and which one is taken apart
src11 = (conjB) ? a : b;
732+
src21 = (conjB) ? b : a;
733+
734+
kstrcpy(&swSrc1, src11->buf);
735+
swapComplexComponents(&swSrc1, 1);
736+
takeComplexApart(&reSrc2, &imSrc2, src21);
737+
738+
if (conjA && conjB) {
739+
// the swapped operand goes into the first product, the plain one
// into the second; the imaginary part and the second vector
// component are negated
src12 = src11;
740+
src11 = &swSrc1;
741+
src21 = &imSrc2;
742+
src22 = &reSrc2;
743+
sign1 = sign3 = "-";
744+
}
745+
else {
746+
src12 = &swSrc1;
747+
src21 = &reSrc2;
748+
src22 = &imSrc2;
749+
if (conjA || conjB) {
750+
// exactly one conjugation flag is set: negate the second component
sign3 = "-";
751+
}
752+
else {
753+
// no conjugation: negate the first component
sign2 = "-";
754+
}
755+
}
756+
757+
if (core == TILEMUL_MAD) {
758+
const char *strC = (c == NULL) ? "0" : c->buf;
759+
760+
// two statements: the first mad() stores into "sctmp", the second
// uses "sctmp" as its addend and assigns the result to dst
ksprintf(expr, "%s = mad(%s, %s%s, %s);\n"
761+
"%s = mad(%s, (%s)(%s%s, %s%s), %s);\n",
762+
"sctmp", src11->buf, sign1, src21->buf, strC,
763+
dst->buf, src12->buf, baseType, sign2, src22->buf,
764+
sign3, src22->buf, "sctmp");
765+
}
766+
else {
767+
// accumulate in place when the addend is the destination itself
const char *op = (dst == c) ? "+=" : "=";
768+
769+
ksprintf(expr, "%s %s %s * %s%s + %s * (%s)(%s%s, %s%s)",
770+
dst->buf, op, src11->buf, sign1,
771+
src21->buf, src12->buf, baseType, sign2, src22->buf,
772+
sign3, src22->buf);
773+
// a distinct, non-NULL addend is appended as an explicit term
if (!((c == NULL) || (c == dst))) {
774+
kstrcatf(expr, " + %s", c->buf);
775+
}
776+
kstrcatf(expr, "%s", ";\n");
777+
}
778+
}
779+
780+
696781
int
697782
genMulTiles(
698783
struct KgenContext *ctx,

src/library/tools/ktest/naive/naive_blas.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,13 +350,13 @@ operator/(FloatComplex a, cl_float b)
350350
static __inline DoubleComplex
351351
operator+(DoubleComplex a, DoubleComplex b)
352352
{
353-
return doubleComplex(CREAL(a) + CREAL(b), CIMAG(b) + CIMAG(b));
353+
return doubleComplex(CREAL(a) + CREAL(b), CIMAG(a) + CIMAG(b));
354354
}
355355

356356
static __inline DoubleComplex
357357
operator-(DoubleComplex a, DoubleComplex b)
358358
{
359-
return doubleComplex(CREAL(a) - CREAL(b), CIMAG(b) - CIMAG(b));
359+
return doubleComplex(CREAL(a) - CREAL(b), CIMAG(a) - CIMAG(b));
360360
}
361361

362362
static __inline DoubleComplex

0 commit comments

Comments
 (0)