Skip to content

Commit f6ae9ac

Browse files
committed
AutoGemm KernelOpenCL can generate standalone kernels
1 parent e7e01ad commit f6ae9ac

File tree

1 file changed

+42
-4
lines changed

1 file changed

+42
-4
lines changed

src/library/blas/AutoGemm/KernelOpenCL.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import Common
55
import KernelParameters
66
import AutoGemmParameters
7+
import argparse
78

89

910
##############################################################################
@@ -541,9 +542,46 @@ def writeOpenCLKernels():
541542
# Main
542543
################################################################################
543544
if __name__ == "__main__":
544-
if len(sys.argv) == 2:
545-
Common.setOutputPath(sys.argv[1])
545+
ap = argparse.ArgumentParser(description="KernelOpenCL")
546+
ap.add_argument("precision", choices=["s","d","c","z"], help="precision" )
547+
ap.add_argument("order", choices=["row","col"], help="order: row major or column major" )
548+
ap.add_argument("transA", choices=["N","T", "C"], help="transA" )
549+
ap.add_argument("transB", choices=["N","T", "C"], help="transB" )
550+
ap.add_argument("beta", choices=[0, 1], type=int, help="0 for beta is zero, 1 for beta is non-zero" )
551+
ap.add_argument("workGroupNumRows", type=int )
552+
ap.add_argument("workGroupNumCols", type=int )
553+
ap.add_argument("microTileNumRows", type=int )
554+
ap.add_argument("microTileNumCols", type=int )
555+
ap.add_argument("unroll", type=int, help="number of iterations to unroll the loop over k" )
556+
ap.add_argument("outputPath", default=".", help="output path; %s will be appended to path" % Common.getRelativeKernelSourcePath() )
557+
558+
args = ap.parse_args()
559+
560+
kernel = KernelParameters.KernelParameters()
561+
kernel.precision = args.precision
562+
if args.order == "col":
563+
kernel.order = "clblasColumnMajor"
546564
else:
547-
print "Warning: No output path specified; default is working directory."
548-
writeOpenCLKernels()
565+
kernel.order = "clblasRowMajor"
566+
kernel.transA = args.transA
567+
kernel.transB = args.transB
568+
kernel.beta = args.beta
569+
kernel.workGroupNumRows = args.workGroupNumRows
570+
kernel.workGroupNumCols = args.workGroupNumCols
571+
kernel.microTileNumRows = args.microTileNumRows
572+
kernel.microTileNumCols = args.microTileNumCols
573+
kernel.unroll = args.unroll
574+
Common.setOutputPath(args.outputPath)
575+
576+
kernel.macroTileNumRows = kernel.workGroupNumRows * kernel.microTileNumRows
577+
kernel.macroTileNumCols = kernel.workGroupNumCols * kernel.microTileNumCols
578+
579+
if not os.path.exists( Common.getKernelSourcePath() ):
580+
os.makedirs( Common.getKernelSourcePath() )
581+
582+
writeOpenCLKernelToFile(kernel)
583+
584+
kernelName = kernel.getName()
585+
kernelFileName = Common.getKernelSourcePath() + kernelName +"_src.cpp"
586+
print "kernel \"%s\" written to %s" % (kernelName, kernelFileName)
549587

0 commit comments

Comments
 (0)