|
4 | 4 | import Common
|
5 | 5 | import KernelParameters
|
6 | 6 | import AutoGemmParameters
|
| 7 | +import argparse |
7 | 8 |
|
8 | 9 |
|
9 | 10 | ##############################################################################
|
@@ -541,9 +542,46 @@ def writeOpenCLKernels():
|
541 | 542 | # Main
|
542 | 543 | ################################################################################
|
if __name__ == "__main__":
  # Command-line entry point: build one KernelParameters object from the
  # arguments and write the corresponding OpenCL kernel source file.
  ap = argparse.ArgumentParser(description="KernelOpenCL")
  ap.add_argument("precision", choices=["s", "d", "c", "z"], help="precision")
  ap.add_argument("order", choices=["row", "col"], help="order: row major or column major")
  ap.add_argument("transA", choices=["N", "T", "C"], help="transA")
  ap.add_argument("transB", choices=["N", "T", "C"], help="transB")
  ap.add_argument("beta", choices=[0, 1], type=int, help="0 for beta is zero, 1 for beta is non-zero")
  ap.add_argument("workGroupNumRows", type=int)
  ap.add_argument("workGroupNumCols", type=int)
  ap.add_argument("microTileNumRows", type=int)
  ap.add_argument("microTileNumCols", type=int)
  ap.add_argument("unroll", type=int, help="number of iterations to unroll the loop over k")
  # nargs="?" is required for a positional argument's default to take effect;
  # without it argparse treats outputPath as mandatory and ignores default=".".
  ap.add_argument("outputPath", nargs="?", default=".", help="output path; %s will be appended to path" % Common.getRelativeKernelSourcePath())

  args = ap.parse_args()

  kernel = KernelParameters.KernelParameters()
  kernel.precision = args.precision
  # Map the short command-line spelling onto the order names stored on the kernel.
  if args.order == "col":
    kernel.order = "clblasColumnMajor"
  else:
    kernel.order = "clblasRowMajor"
  kernel.transA = args.transA
  kernel.transB = args.transB
  kernel.beta = args.beta
  kernel.workGroupNumRows = args.workGroupNumRows
  kernel.workGroupNumCols = args.workGroupNumCols
  kernel.microTileNumRows = args.microTileNumRows
  kernel.microTileNumCols = args.microTileNumCols
  kernel.unroll = args.unroll
  Common.setOutputPath(args.outputPath)

  # Macro tile dimensions are the work-group dimensions times the micro tile dimensions.
  kernel.macroTileNumRows = kernel.workGroupNumRows * kernel.microTileNumRows
  kernel.macroTileNumCols = kernel.workGroupNumCols * kernel.microTileNumCols

  # Create the destination directory. Attempting the creation and tolerating
  # an already-existing directory avoids the check-then-create race.
  kernelSourcePath = Common.getKernelSourcePath()
  try:
    os.makedirs(kernelSourcePath)
  except OSError:
    if not os.path.isdir(kernelSourcePath):
      raise

  writeOpenCLKernelToFile(kernel)

  kernelName = kernel.getName()
  kernelFileName = kernelSourcePath + kernelName + "_src.cpp"
  # Parenthesised single-argument form prints identically on Python 2 and 3.
  print("kernel \"%s\" written to %s" % (kernelName, kernelFileName))
|
0 commit comments