|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +#pragma once |
| 5 | + |
| 6 | +#include <stdexcept> |
| 7 | +#include <string> |
| 8 | +#include <sstream> |
| 9 | +#include <memory> |
| 10 | +#include <cassert> |
| 11 | + |
| 12 | +#include <cuda_runtime.h> |
| 13 | +#include "span.h" |
| 14 | + |
1 | 15 | namespace Generators {
|
2 | 16 |
|
3 | 17 | cudaStream_t GetStream();
|
@@ -91,4 +105,54 @@ cuda_unique_ptr<T> CudaMallocArray(size_t count, std::span<T>* p_span = nullptr)
|
91 | 105 | return cuda_unique_ptr<T>{p};
|
92 | 106 | }
|
93 | 107 |
|
| 108 | +inline int CeilDiv(int a, int b) { return (a + (b - 1)) / b; } |
| 109 | + |
| 110 | +class CudaError : public std::runtime_error { |
| 111 | + public: |
| 112 | + explicit CudaError(const std::string& msg, cudaError_t code) |
| 113 | + : std::runtime_error(msg), code_(code) {} |
| 114 | + |
| 115 | + cudaError_t code() const noexcept { return code_; } |
| 116 | + |
| 117 | + private: |
| 118 | + cudaError_t code_; |
| 119 | +}; |
| 120 | + |
| 121 | +#define CUDA_CHECK(call) \ |
| 122 | + do { \ |
| 123 | + cudaError_t err = (call); \ |
| 124 | + if (err != cudaSuccess) { \ |
| 125 | + std::stringstream ss; \ |
| 126 | + ss << "CUDA error in " << __func__ << " at " << __FILE__ \ |
| 127 | + << ":" << __LINE__ << " - " << cudaGetErrorString(err); \ |
| 128 | + throw Generators::CudaError(ss.str(), err); \ |
| 129 | + } \ |
| 130 | + } while (0) |
| 131 | + |
| 132 | +#ifdef NDEBUG |
| 133 | +#define CUDA_CHECK_LAUNCH() \ |
| 134 | + do { \ |
| 135 | + cudaError_t err = cudaPeekAtLastError(); \ |
| 136 | + if (err != cudaSuccess) { \ |
| 137 | + std::stringstream ss; \ |
| 138 | + ss << "CUDA launch error in " << __func__ << " at " \ |
| 139 | + << __FILE__ << ":" << __LINE__ << " - " \ |
| 140 | + << cudaGetErrorString(err); \ |
| 141 | + throw Generators::CudaError(ss.str(), err); \ |
| 142 | + } \ |
| 143 | + } while (0) |
| 144 | +#else |
| 145 | +#define CUDA_CHECK_LAUNCH() \ |
| 146 | + do { \ |
| 147 | + cudaError_t err = cudaGetLastError(); \ |
| 148 | + if (err != cudaSuccess) { \ |
| 149 | + std::stringstream ss; \ |
| 150 | + ss << "CUDA launch error in " << __func__ << " at " \ |
| 151 | + << __FILE__ << ":" << __LINE__ << " - " \ |
| 152 | + << cudaGetErrorString(err); \ |
| 153 | + throw Generators::CudaError(ss.str(), err); \ |
| 154 | + } \ |
| 155 | + } while (0) |
| 156 | +#endif |
| 157 | + |
94 | 158 | } // namespace Generators
|
0 commit comments