Skip to content

Commit 0773668

Browse files
authored
SPIRV Runner: User driven generic SYCL output buffer copy (D2H) (#2954)
This PR conditionally copies the output buffers (buffer_cnt >= 0) back to the host (D2H) based on the CLI option (-o), which accepts a comma-separated list of tensor names. Previously, only a single output tensor copy was supported.
1 parent 3ccab57 commit 0773668

File tree

3 files changed

+55
-37
lines changed

3 files changed

+55
-37
lines changed

utils/SPIRVRunner/SPIRVRunner.cpp

Lines changed: 48 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ auto read_spirv(const std::string &filename) {
4747
return read_file_as_bytes(filename);
4848
}
4949

50+
// Host output tensor buffers and indexes
51+
struct TensorBuffer {
52+
torch::Tensor buffer_ptr;
53+
size_t index;
54+
};
55+
5056
// Structure that contains Triton kernel arguments
5157
struct KernelArguments {
5258
int gridX;
@@ -62,11 +68,11 @@ struct KernelArguments {
6268
std::string spv_name;
6369
ordered_json jsonData;
6470
std::vector<char *> dev_buffers;
65-
torch::Tensor host_outbuffer;
66-
std::string out_tensor_name;
71+
std::vector<TensorBuffer> host_outbuffers;
72+
std::vector<std::string> out_tensor_names;
6773
std::string spirv_dump_dir;
6874

69-
KernelArguments(const std::string &outtensorname) {
75+
KernelArguments(const std::vector<std::string> &outtensornames) {
7076
// Check if the triton_xpu_dump path exists if not point to current
7177
// directory
7278
auto env_path = std::getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS");
@@ -97,7 +103,7 @@ struct KernelArguments {
97103
build_flags = jsonData.at("build_flags");
98104
spv_name =
99105
spirv_dump_dir + "/" + jsonData.at("spv_name").get<std::string>();
100-
out_tensor_name = outtensorname;
106+
out_tensor_names = outtensornames;
101107
}
102108
};
103109

@@ -348,13 +354,14 @@ at::TensorOptions getTensorOptions(const std::string &dtype) {
348354
}
349355
}
350356

351-
at::Tensor launchKernel(sycl::queue stream, sycl::kernel kernel,
352-
KernelArguments triton_args, bool get_kernel_time) {
357+
std::vector<TensorBuffer> launchKernel(sycl::queue stream, sycl::kernel kernel,
358+
KernelArguments triton_args,
359+
bool get_kernel_time) {
353360

354361
auto tensor_ptr = [](const torch::Tensor &t) -> void * {
355362
return static_cast<void *>(t.data_ptr());
356363
};
357-
int devout_idx = 0;
364+
358365
for (auto &item : triton_args.jsonData["argument_list"]) {
359366
if (item.contains("type")) {
360367
if (item.at("type").get<std::string>() == "tensor") {
@@ -369,36 +376,40 @@ at::Tensor launchKernel(sycl::queue stream, sycl::kernel kernel,
369376
.wait_and_throw();
370377

371378
// Configure output tensor
372-
if (item.at("name").get<std::string>() == triton_args.out_tensor_name) {
373-
devout_idx = triton_args.dev_buffers.size() - 1;
374-
triton_args.host_outbuffer = torch::zeros(
375-
{tensor.sizes()}, getTensorOptions(item.at("dtype")));
376-
std::cout << "Tensor output: " << triton_args.host_outbuffer.sizes()
377-
<< ", " << triton_args.host_outbuffer.scalar_type() << " ("
378-
<< triton_args.host_outbuffer.nbytes() << " bytes)"
379-
<< std::endl;
379+
if (std::find(triton_args.out_tensor_names.begin(),
380+
triton_args.out_tensor_names.end(),
381+
item.at("name").get<std::string>()) !=
382+
triton_args.out_tensor_names.end()) {
383+
TensorBuffer tb;
384+
tb.buffer_ptr = torch::zeros({tensor.sizes()},
385+
getTensorOptions(item.at("dtype")));
386+
tb.index = triton_args.dev_buffers.size() - 1;
387+
triton_args.host_outbuffers.push_back(tb);
388+
std::cout
389+
<< "Tensor output[" << triton_args.host_outbuffers.back().index
390+
<< "]: " << triton_args.host_outbuffers.back().buffer_ptr.sizes()
391+
<< ", "
392+
<< triton_args.host_outbuffers.back().buffer_ptr.scalar_type()
393+
<< " (" << triton_args.host_outbuffers.back().buffer_ptr.nbytes()
394+
<< " bytes)" << std::endl;
380395
}
381396
}
382397
} else {
383398
throw std::runtime_error("Type entry is missing in JSON argument_list");
384399
}
385400
}
386401

387-
if (!triton_args.host_outbuffer.defined()) {
388-
std::string message = "Output tensor isn't configured; \
389-
the second positional parameter is ";
390-
throw std::runtime_error(message + triton_args.out_tensor_name);
391-
}
392-
393402
// Launch SYCL kernel
394403
sycl_kernel_launch(stream, kernel, triton_args, get_kernel_time);
395404

396-
// copy back
397-
stream
398-
.memcpy(tensor_ptr(triton_args.host_outbuffer),
399-
triton_args.dev_buffers.at(devout_idx),
400-
triton_args.host_outbuffer.nbytes())
401-
.wait_and_throw();
405+
// copy back the output tensors
406+
for (const auto &item : triton_args.host_outbuffers) {
407+
stream
408+
.memcpy(tensor_ptr(item.buffer_ptr),
409+
triton_args.dev_buffers.at(item.index),
410+
item.buffer_ptr.nbytes())
411+
.wait_and_throw();
412+
}
402413

403414
for (auto *dev_ptr : triton_args.dev_buffers) {
404415
if (dev_ptr)
@@ -407,7 +418,7 @@ at::Tensor launchKernel(sycl::queue stream, sycl::kernel kernel,
407418
throw std::runtime_error("sycl::free failed \n");
408419
}
409420

410-
return triton_args.host_outbuffer;
421+
return triton_args.host_outbuffers;
411422
}
412423

413424
int main(int argc, char **argv) {
@@ -430,7 +441,7 @@ int main(int argc, char **argv) {
430441
initDevices(&q);
431442

432443
// Parse the JSON file and create argument dictionary
433-
KernelArguments tritonArgDict(cliopts.output_tensor);
444+
KernelArguments tritonArgDict(cliopts.output_tensors);
434445

435446
// read spirv
436447
auto spirv = read_spirv(tritonArgDict.spv_name);
@@ -444,12 +455,16 @@ int main(int argc, char **argv) {
444455
std::cout << "Loaded kernel with " << n_regs << " registers and "
445456
<< n_spills << " register spills." << std::endl;
446457

447-
auto output =
458+
auto output_tensors =
448459
launchKernel(q, kernel, tritonArgDict, cliopts.get_kernel_time);
449460

450-
auto output_tensor = tritonArgDict.spirv_dump_dir + "/cpp_outs.pt";
451-
write_tensor(output_tensor, output);
452-
std::cout << "Output Tensor Path: " << output_tensor << std::endl;
461+
// Write output tensors to file
462+
for (auto &item : output_tensors) {
463+
auto output_tensor = tritonArgDict.spirv_dump_dir + "/cpp_outs_" +
464+
std::to_string(item.index) + ".pt";
465+
write_tensor(output_tensor, item.buffer_ptr);
466+
std::cout << "Output Tensor Path: " << output_tensor << std::endl;
467+
}
453468
} catch (const std::runtime_error &e) {
454469
std::cerr << "Error: " << e.what() << std::endl;
455470
return EXIT_FAILURE;

utils/SPIRVRunner/llvm_parser.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,18 @@ command_line_parser::command_line_parser(int argc, char **argv)
55

66
command_line_parser::options command_line_parser::parse() {
77
options opts;
8-
llvm::cl::opt<std::string> output_tensor(
9-
"o", llvm::cl::desc("<Specify Output Tensor Name>"), llvm::cl::Required);
8+
llvm::cl::list<std::string> output_tensors(
9+
"o",
10+
llvm::cl::desc(
11+
"<Specify Output Tensor Names (Ex: -o tensor_1,tensor_2 or skip)>"),
12+
llvm::cl::CommaSeparated);
1013
llvm::cl::opt<bool> enable_profiling(
1114
"p", llvm::cl::desc("Enable kernel time profiling"),
1215
llvm::cl::init(opts.get_kernel_time));
1316

1417
llvm::cl::ParseCommandLineOptions(argc, argv, "SPIRVRunner\n");
1518

16-
opts.output_tensor = output_tensor;
19+
opts.output_tensors.assign(output_tensors.begin(), output_tensors.end());
1720
opts.get_kernel_time = enable_profiling;
1821

1922
return opts;

utils/SPIRVRunner/llvm_parser.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
class command_line_parser {
77
public:
88
struct options {
9-
std::string output_tensor;
9+
std::vector<std::string> output_tensors;
1010
bool get_kernel_time = false;
1111
};
1212

0 commit comments

Comments (0)