diff --git a/.gitignore b/.gitignore index 31d4260..2012679 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ build/ .cache/ .vscode/ Data/ +third_party/ \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index ad4b6fa..157f5ee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,6 +8,20 @@ include(CMakeDependentOption) cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF) project(infini_train VERSION 0.3.0 LANGUAGES CXX) +# 添加这些行 +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Debug) +endif() + +# 确保Debug模式包含调试信息 +set(CMAKE_CXX_FLAGS_DEBUG "-g -O0") +set(CMAKE_C_FLAGS_DEBUG "-g -O0") + +# 如果使用CUDA,也为CUDA添加调试标志 +if(USE_CUDA) + set(CMAKE_CUDA_FLAGS_DEBUG "-g -O0") +endif() + set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) @@ -52,7 +66,18 @@ if(USE_CUDA) target_link_libraries(infini_train_cuda_kernels glog CUDA::cudart CUDA::cublas) add_library(infini_train STATIC ${SRC}) - target_link_libraries(infini_train glog gflags "-Wl,--whole-archive" infini_train_cpu_kernels infini_train_cuda_kernels "-Wl,--no-whole-archive") + # 修改这行:添加CUDA库链接 + target_link_libraries(infini_train + glog + gflags + CUDA::cudart + CUDA::cuda_driver # 添加这个,解决 cuInit 等函数 + CUDA::cublas + "-Wl,--whole-archive" + infini_train_cpu_kernels + infini_train_cuda_kernels + "-Wl,--no-whole-archive" + ) else() add_library(infini_train STATIC ${SRC}) target_link_libraries(infini_train glog gflags "-Wl,--whole-archive" infini_train_cpu_kernels "-Wl,--no-whole-archive") diff --git a/Makefile b/Makefile index c5a12a8..7a70cb1 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ CMAKE_OPT += -DUSE_CUDA=$(USE_CUDA) build: mkdir -p build/$(TYPE) - cd build/$(TYPE) && cmake $(CMAKE_OPT) ../.. && make -j8 + cd build/$(TYPE) && cmake $(CMAKE_OPT) ../.. 
&& make -j clean: rm -rf build diff --git "a/docs/TinyInfiniTrain \344\275\234\344\270\232\346\212\245\345\221\212.md" "b/docs/TinyInfiniTrain \344\275\234\344\270\232\346\212\245\345\221\212.md" index bc23852..f132719 100644 --- "a/docs/TinyInfiniTrain \344\275\234\344\270\232\346\212\245\345\221\212.md" +++ "b/docs/TinyInfiniTrain \344\275\234\344\270\232\346\212\245\345\221\212.md" @@ -2,6 +2,8 @@ ## 一、test 通过截图 +![通过截图](image.png) + ## 二、作业步骤 > 将代码填入下面代码块中指定位置,并详细描述完成该作业的解决思路和遇到的问题。 @@ -17,26 +19,32 @@ ```c++ std::vector> Neg::Forward(const std::vector> &input_tensors) { // =================================== 作业 =================================== - // TODO:通过Dispatcher获取设备专属kernel,对输入张量进行取反操作 - // HINT: 依赖test_dispatcher,kernel实现已给出 - // =================================== 作业 =================================== + CHECK_EQ(input_tensors.size(), 1); + const auto &input = input_tensors[0]; + + auto device = input->GetDevice().Type(); + auto kernel = Dispatcher::Instance().GetKernel({device, "NegForward"}); + return {kernel.Call>(input)}; } std::vector> Neg::Backward(const std::vector> &grad_outputs) { // =================================== 作业 =================================== - // TODO:通过Dispatcher获取设备专属的反向传播kernel,计算梯度 - // HINT: 依赖test_dispatcher,kernel实现已给出 - // =================================== 作业 =================================== + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + + auto device = grad_output->GetDevice().Type(); + auto kernel = Dispatcher::Instance().GetKernel({device, "NegBackward"}); + return {kernel.Call>(grad_output)}; } ``` #### 解决思路 - +`input_tensors`作为参数,通过Dispatcher调用NegBackward的计算Kernel获取结果即可 #### 遇到问题 - +无 ### 作业二:实现矩阵乘法 @@ -49,21 +57,98 @@ std::vector> Neg::Backward(const std::vector MatmulForward(const std::shared_ptr &input, const std::shared_ptr &other) { - // =================================== 作业 =================================== - // TODO:实现CPU上的矩阵乘法前向计算 - // REF: - // =================================== 作业 =================================== +std::shared_ptr MatmulForward(const std::shared_ptr &input, const std::shared_ptr &other) { + // =================================== 作业 =================================== + /* + output[*, m, n] = input[*, m, k] * other[*, k, n] + */ + // TODO(dcj): support broadcast later + const auto &input_dims = input->Dims(); + const auto &other_dims = other->Dims(); + + CHECK_GE(input_dims.size(), 2); + CHECK_GE(other_dims.size(), 2); + CHECK_EQ(input_dims.size(), other_dims.size()); + + const int64_t m = input_dims[input_dims.size() - 2]; + const int64_t k = input_dims[input_dims.size() - 1]; + CHECK_EQ(k, other_dims[other_dims.size() - 2]); + const int64_t n = other_dims[other_dims.size() - 1]; + + const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < input_dims.size() - 2; ++i) { + CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; } - std::tuple, std::shared_ptr> - MatmulBackward(const std::shared_ptr &input, const std::shared_ptr &other, - const std::shared_ptr &grad_output) { - // =================================== 作业 =================================== - // TODO:实现CPU上的矩阵乘法反向传播 - // REF: - // =================================== 作业 =================================== + std::vector output_dims = input_dims; + output_dims[output_dims.size() - 1] = n; + auto output = std::make_shared(output_dims, DataType::kFLOAT32); + + for (int64_t b = 0; b < bs; ++b) { + for (int64_t i = 0; i < m; ++i) { 
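+                // Naive batched GEMM: output[b, i, j] is the dot product of row i of input[b]
+                // and column j of other[b] over the shared dimension k.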
+ for (int64_t j = 0; j < n; ++j) { + float acc = 0.0f; + for (int64_t p = 0; p < k; ++p) { + acc += static_cast(input->DataPtr())[b * m * k + i * k + p] + * static_cast(other->DataPtr())[b * k * n + p * n + j]; + } + static_cast(output->DataPtr())[b * m * n + i * n + j] = acc; + } + } } + return {output}; +} + +MatmulBackward(const std::shared_ptr &input, const std::shared_ptr &other, + const std::shared_ptr &grad_output) { + // =================================== 作业 =================================== + /* + grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T + grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n] + */ + const auto &input_dims = input->Dims(); + const auto &other_dims = other->Dims(); + const auto &grad_output_dims = grad_output->Dims(); + + CHECK_GE(input_dims.size(), 2); + CHECK_EQ(input_dims.size(), other_dims.size()); + CHECK_EQ(input_dims.size(), grad_output_dims.size()); + + const int64_t m = input_dims[input_dims.size() - 2]; + const int64_t k = input_dims[input_dims.size() - 1]; + CHECK_EQ(k, other_dims[other_dims.size() - 2]); + const int64_t n = other_dims[other_dims.size() - 1]; + CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]); + CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]); + + const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < input_dims.size() - 2; ++i) { + CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; + CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match"; + } + + auto grad_input = std::make_shared(input_dims, DataType::kFLOAT32); + auto grad_other = std::make_shared(other_dims, DataType::kFLOAT32); + grad_input->Fill(0.0f); + grad_other->Fill(0.0f); + + for (int64_t b = 0; b < bs; ++b) { + for (int64_t i = 0; i < m; ++i) { + for (int64_t j = 0; j < n; ++j) { + const float grad = static_cast(grad_output->DataPtr())[b * m * n + i * n + j]; + for (int64_t p = 0; p < k; ++p) { + const auto input_idx = b * m * k + i * k + p; + const auto other_idx = b * k * n + p * n + j; + static_cast(grad_input->DataPtr())[input_idx] + += grad * static_cast(other->DataPtr())[other_idx]; + static_cast(grad_other->DataPtr())[other_idx] + += grad * static_cast(input->DataPtr())[input_idx]; + } + } + } + } + return {grad_input, grad_other}; +} ``` #### CUDA实现 @@ -73,30 +158,181 @@ std::vector> Neg::Backward(const std::vector MatmulForward(const std::shared_ptr &input, const std::shared_ptr &other) { - // =================================== 作业 =================================== - // TODO:实现CUDA上的矩阵乘法前向计算 - // REF: - // =================================== 作业 =================================== +std::shared_ptr MatmulForward(const std::shared_ptr &input, const std::shared_ptr &other) { + // =================================== 作业 =================================== + /* + output[*, m, n] = input[*, m, k] * other[*, k, n] + */ + const auto &input_dims = input->Dims(); + const auto &other_dims = other->Dims(); + + CHECK_GE(input_dims.size(), 2); + CHECK_GE(other_dims.size(), 2); + CHECK_EQ(input_dims.size(), other_dims.size()); + + const int64_t m = input_dims[input_dims.size() - 2]; + const int64_t k = input_dims[input_dims.size() - 1]; + CHECK_EQ(k, other_dims[other_dims.size() - 2]); + const int64_t n = other_dims[other_dims.size() - 1]; + + const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < input_dims.size() - 2; ++i) { + CHECK_EQ(input_dims[i], 
other_dims[i]) << "Batch dims must match"; } - std::tuple, std::shared_ptr> - MatmulBackward(const std::shared_ptr &input, const std::shared_ptr &other, - const std::shared_ptr &grad_output) { - // =================================== 作业 =================================== - // TODO:实现CUDA上的矩阵乘法反向传播 - // REF: - // =================================== 作业 =================================== + auto dtype = input->Dtype(); + std::vector output_dims = input_dims; + output_dims[output_dims.size() - 1] = n; + auto output = std::make_shared(output_dims, dtype, input->GetDevice()); + + const auto *cuda_device = dynamic_cast( + DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, input->GetDevice().Index())); + const float alpha = 1.0f, beta = 0.0f; + cublasHandle_t handle = cuda_device->CublasHandle(); + + // cuBLAS is colmun-major + // output = input * other --> output.T = other.T * input.T + // C = A * B ==> output.T[*, n, m] = other.T[*, n, k] * input.T[*, k, m] + // C = output.T[*, n, m] + // A = other.T[*, n, k] + // B = input.T[*, k, m] + int lda = n; + int ldb = k; + int ldc = n; + int64_t stride_a = n * k; + int64_t stride_b = k * m; + int64_t stride_c = m * n; + // NOTE(zbl): the last cublasGemmAlgo_t param has no effect on GPU arch >= sm_80(Ampere) + + switch (dtype) { + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, other->DataPtr(), CUDA_R_32F, lda, + stride_a, input->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, output->DataPtr(), CUDA_R_32F, + ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kFLOAT32) + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, other->DataPtr(), CUDA_R_16BF, lda, + stride_a, input->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, output->DataPtr(), CUDA_R_16BF, + ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kBFLOAT16) + default: + LOG_UNSUPPORTED_DTYPE(dtype, "CUDA MatmulForward"); + } + + return output; +} + +MatmulBackward(const std::shared_ptr &input, const std::shared_ptr &other, + const std::shared_ptr &grad_output) { + // =================================== 作业 =================================== + /* + grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T + grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n] + */ + const auto &input_dims = input->Dims(); + const auto &other_dims = other->Dims(); + const auto &grad_output_dims = grad_output->Dims(); + + CHECK_GE(input_dims.size(), 2); + CHECK_EQ(input_dims.size(), other_dims.size()); + CHECK_EQ(input_dims.size(), grad_output_dims.size()); + + const int64_t m = input_dims[input_dims.size() - 2]; + const int64_t k = input_dims[input_dims.size() - 1]; + const int64_t n = other_dims[other_dims.size() - 1]; + CHECK_EQ(k, other_dims[other_dims.size() - 2]); + CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]); + CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]); + + const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < input_dims.size() - 2; ++i) { + CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; + CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match"; } + + auto dtype = input->Dtype(); + auto grad_input = std::make_shared(input_dims, dtype, grad_output->GetDevice()); + auto grad_other = std::make_shared(other_dims, dtype, grad_output->GetDevice()); + + DispatchFunc( + dtype, + [=]() { + 
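+            // Zero-initialize both gradient buffers in the dispatched element type
+            // before the cuBLAS GEMM calls below.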
grad_input->Fill(0); + grad_other->Fill(0); + }, + "CUDA MatmulBackward"); + + const auto *cuda_device = dynamic_cast( + DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, input->GetDevice().Index())); + const float alpha = 1.0f, beta = 0.0f; + cublasHandle_t handle = cuda_device->CublasHandle(); + + { + // cuBLAS is colmun-major + // grad_input = grad_output * other.T --> grad_input.T = other * grad_output.T + // C = A.T * B ==> grad_input.T[*, k, m] = other[*, k, n] * grad_output.T[*, n, m] + // C = grad_input.T[*, k, m] + // A = other.T[*, n, k] + // B = grad_output.T[*, n, m] + const int lda = n, ldb = n, ldc = k; + const int64_t stride_a = k * n; + const int64_t stride_b = n * m; + const int64_t stride_c = m * k; + switch (dtype) { + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other->DataPtr(), CUDA_R_32F, lda, + stride_a, grad_output->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, grad_input->DataPtr(), + CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kFLOAT32) + DISPATCH_CASE( + WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other->DataPtr(), CUDA_R_16BF, lda, stride_a, + grad_output->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, grad_input->DataPtr(), CUDA_R_16BF, ldc, + stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kBFLOAT16) + } + } + + { + // cuBLAS is colmun-major + // grad_other = input.T * grad_output --> grad_other.T = grad_output.T * input + // C = A * B.T ==> grad_other.T[*, n, k] = grad_output.T[*, n, m] * input[*, m, k] + // C = grad_other.T[*, n, k] + // A = grad_output.T[*, n, m] + // B = input.T[*, k, m] + const int lda = n, ldb = k, ldc = n; + const int64_t stride_a = n * m; + const int64_t stride_b = k * m; + const int64_t stride_c = n * k; + switch (dtype) { + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output->DataPtr(), CUDA_R_32F, + lda, stride_a, input->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, grad_other->DataPtr(), + CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kFLOAT32) + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output->DataPtr(), CUDA_R_16BF, + lda, stride_a, input->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, grad_other->DataPtr(), + CUDA_R_16BF, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kBFLOAT16) + } + } + + return {grad_input, grad_other}; +} ``` #### 解决思路 - +主要是复制粘贴了InfiniTrain中提供的代码,CPU实现比较直观,直接使用for循环实现矩阵的乘法,但是唯一的问题在于TinyInfiniTrain中的GetDevice()的返回值与InfiniTrain中GetDevice()的返回值不同 +```c++ + const auto *cuda_device = dynamic_cast( + DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, input->GetDevice().Index())); +``` #### 遇到问题 - +无 ### 作业三:实现Adam优化器 @@ -113,9 +349,24 @@ void AdamAccumulateGrad(const std::shared_ptr &grad, const std::shared_p const std::shared_ptr &m, const std::shared_ptr &v, float learning_rate, float beta1, float beta2, float eps, int64_t t) { // =================================== 作业 =================================== - // TODO:实现Adam优化器的梯度累积和参数更新 - // REF: - // =================================== 作业 =================================== + const float *grad_data = static_cast(grad->DataPtr()); + float *m_data = static_cast(m->DataPtr()); + float *v_data = static_cast(v->DataPtr()); + float *param_data = static_cast(param->DataPtr()); + + const 
float bias_correction_m = 1.0f - std::pow(beta1, t); + const float bias_correction_v = 1.0f - std::pow(beta2, t); + +#pragma omp parallel for + for (size_t idx = 0; idx < grad->NumElements(); ++idx) { + m_data[idx] = beta1 * m_data[idx] + (1 - beta1) * grad_data[idx]; + v_data[idx] = beta2 * v_data[idx] + (1 - beta2) * grad_data[idx] * grad_data[idx]; + + const float m_hat = m_data[idx] / bias_correction_m; + const float v_hat = v_data[idx] / bias_correction_v; + + param_data[idx] -= learning_rate * m_hat / (std::sqrt(v_hat) + eps); + } } ``` @@ -126,23 +377,61 @@ void AdamAccumulateGrad(const std::shared_ptr &grad, const std::shared_p 代码位置:infini_train/src/kernels/cuda/accumulate_grad.cu ```c++ +template +__global__ void AdamAccumulateGradKernel(const T *grad_data, T *param_data, size_t num_elements, T *m_data, T *v_data, + float learning_rate, float beta1, float beta2, float eps, + const float bias_correction_m, const float bias_correction_v) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < num_elements) { + m_data[idx] = common::cuda::Fma(common::cuda::Cast(beta1), m_data[idx], + common::cuda::Cast(1 - beta1) * grad_data[idx]); + v_data[idx] = common::cuda::Fma(common::cuda::Cast(beta2), v_data[idx], + common::cuda::Cast(1 - beta2) * grad_data[idx] * grad_data[idx]); + + const float m_hat = common::cuda::Cast(m_data[idx]) / bias_correction_m; + const float v_hat = common::cuda::Cast(v_data[idx]) / bias_correction_v; + + param_data[idx] = common::cuda::Sub( + param_data[idx], common::cuda::Cast(learning_rate * m_hat * __frcp_rn(__fsqrt_rn(v_hat) + eps))); + } +} + void AdamAccumulateGrad(const std::shared_ptr &grad, const std::shared_ptr ¶m, const std::shared_ptr &m, const std::shared_ptr &v, float learning_rate, float beta1, float beta2, float eps, int64_t t) { // =================================== 作业 =================================== // TODO:实现Adam优化器的梯度累积和参数更新 - // REF: + // REF: // =================================== 作业 =================================== + size_t num_elements = grad->NumElements(); + + const float bias_correction_m = 1.0f - std::pow(beta1, t); + const float bias_correction_v = 1.0f - std::pow(beta2, t); + + int threads_per_block = 256; + int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block; + const auto *cuda_device = dynamic_cast( + DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, grad->GetDevice().Index())); + + DispatchFunc( + grad->Dtype(), + [=]() { + AdamAccumulateGradKernel<<Stream()>>>( + static_cast(grad->DataPtr()), static_cast(param->DataPtr()), num_elements, + static_cast(m->DataPtr()), static_cast(v->DataPtr()), learning_rate, beta1, beta2, eps, + bias_correction_m, bias_correction_v); + }, + "CUDA AdamAccumulateGrad"); } ``` #### 解决思路 - +模仿InfiniTrain中的实现 #### 遇到问题 - +无 ### 作业四:实现Tensor基础操作 @@ -156,10 +445,27 @@ void AdamAccumulateGrad(const std::shared_ptr &grad, const std::shared_p ```c++ std::shared_ptr Tensor::Flatten(int64_t start, int64_t end) { + // return Contiguous()->View(new_shape); // =================================== 作业 =================================== - // TODO:实现张量扁平化操作,将指定维度范围[start, end]内的所有维度合并为一个维度 - // HINT: - // =================================== 作业 =================================== + auto ndim = dims_.size(); + auto start_dim = start >= 0 ? start : start + ndim; + auto end_dim = end >= 0 ? 
end : end + ndim; + CHECK(start_dim >= 0 && end_dim >= start_dim && end_dim <= ndim); + + std::vector new_shape; + int64_t flatten_size = 1; + for (int64_t i = 0; i < ndim; ++i) { + if (i < start_dim || i > end_dim) { + new_shape.push_back(dims_[i]); + } else { + flatten_size *= dims_[i]; + if (i == end_dim) { + new_shape.push_back(flatten_size); + } + } + } + + return Contiguous()->View(new_shape); } ``` @@ -174,20 +480,30 @@ std::shared_ptr Tensor::Flatten(int64_t start, int64_t end) { ```c++ void Tensor::Backward(std::shared_ptr gradient, bool retain_graph, bool create_graph) const { // =================================== 作业 =================================== - // TODO:实现自动微分反向传播 - // 功能描述:1. 计算当前张量对叶子节点的梯度 2. 支持多输出场景的梯度累加 - // HINT: - // =================================== 作业 =================================== + CHECK(!retain_graph && !create_graph) << "Not implemented yet!"; + if (grad_fn_) { + if (!gradient) { + CHECK_EQ(dims_.size(), 0); + gradient = std::make_shared(std::vector{}, dtype_, GetDevice()); + gradient->Fill(1.0f); + } else { + CHECK_EQ(static_cast(GetDevice().Type()), static_cast(gradient->GetDevice().Type())); + CHECK_EQ(static_cast(dtype_), static_cast(gradient->Dtype())); + CHECK_EQ(dims_.size(), gradient->Dims().size()); + for (int idx = 0; idx < dims_.size(); ++idx) { CHECK_EQ(dims_[idx], gradient->Dims()[idx]); } + } + grad_fn_->BackwardPartial(gradient, output_idx_); + } } ``` #### 解决思路 - +模仿InfiniTrain中的实现 #### 遇到问题 - +无 ### 作业五 注册算子kernel的实现 @@ -198,35 +514,57 @@ void Tensor::Backward(std::shared_ptr gradient, bool retain_graph, bool 代码位置:infini_train/include/dispatcher.h ```c++ -template RetT Call(ArgsT... args) const { - // =================================== 作业 =================================== - // TODO:实现通用kernel调用接口 - // 功能描述:将存储的函数指针转换为指定类型并调用 - // HINT: - // =================================== 作业 =================================== -} + template RetT Call(ArgsT... 
args) const { + // =================================== 作业 =================================== +#ifdef PROFILE_MODE + const auto &ctx = GetProfileContext(); + Profiler::Instance().StartRecord(ctx.name, ctx.device); +#endif + + using FuncT = RetT (*)(ArgsT...); + auto fn = reinterpret_cast(func_ptr_); + + if constexpr (std::is_void_v) { + fn(std::forward(args)...); + +#ifdef PROFILE_MODE + Profiler::Instance().EndRecord(ctx.name, ctx.device); +#endif + return; + } else { + RetT ret = fn(std::forward(args)...); + +#ifdef PROFILE_MODE + Profiler::Instance().EndRecord(ctx.name, ctx.device); +#endif + return ret; + } + } -template void Register(const KeyT &key, FuncT &&kernel) { - // =================================== 作业 =================================== - // TODO:实现kernel注册机制 - // 功能描述:将kernel函数与设备类型、名称绑定 - // =================================== 作业 =================================== -} + template void Register(const KeyT &key, FuncT &&kernel) { + // =================================== 作业 =================================== + // TODO:实现kernel注册机制 + // 功能描述:将kernel函数与设备类型、名称绑定 + // =================================== 作业 =================================== + CHECK(!key_to_kernel_map_.contains(key)) + << "Kernel already registered: " << key.second << " on device: " << static_cast(key.first); + key_to_kernel_map_.emplace(key, kernel); + } -#define REGISTER_KERNEL(device, kernel_name, kernel_func) \ - // =================================== 作业 =================================== - // TODO:实现自动注册宏 - // 功能描述:在全局静态区注册kernel,避免显式初始化代码 - // =================================== 作业 =================================== +#define REGISTER_KERNEL(device, kernel_name, kernel_func) \ + static const bool _##kernel_name##_registered##__COUNTER__ = []() { \ + infini_train::Dispatcher::Instance().Register({device, #kernel_name}, kernel_func); \ + return true; \ + }(); ``` #### 解决思路 - +模仿InfiniTrain中的实现 #### 遇到问题 - +无 ### 作业六:实现GPT-2整体训练 @@ -252,13 +590,54 @@ TinyShakespeareFile ReadTinyShakespeareFile(const std::string &path, size_t sequ | magic(4B) | version(4B) | num_toks(4B) | reserved(1012B) | token数据 | ---------------------------------------------------------------------------------- =================================== 作业 =================================== */ + if (!std::filesystem::exists(path)) { + LOG(FATAL) << "File not found: " << path; + } + + TinyShakespeareFile text_file; + std::ifstream ifs(path, std::ios::binary); + const auto header = ReadSeveralBytesFromIfstream(1024, &ifs); + const int magic = BytesToType(header, 0); + const int version = BytesToType(header, 4); + const int num_tokens = BytesToType(header, 8); + text_file.type = kTypeMap.at(magic); + + const int num_sequences = num_tokens / sequence_length; + text_file.dims.assign({num_sequences, static_cast(sequence_length)}); + + const int data_size_in_bytes + = kTypeToSize.at(text_file.type) + * std::accumulate(text_file.dims.begin(), text_file.dims.end(), 1, std::multiplies()); + // shape: (num_seq, seq_len), dtype: int64 + text_file.tensor = infini_train::Tensor(text_file.dims, DataType::kINT64); + int64_t *dst = static_cast(text_file.tensor.DataPtr()); + + std::variant, std::vector> buffer; + if (text_file.type == TinyShakespeareType::kUINT16) { + CHECK_LE(sequence_length, 1024); // GPT-2: max_seq_length = 1024 + buffer = std::vector(num_sequences * sequence_length); + } else if (text_file.type == TinyShakespeareType::kUINT32) { + CHECK_LE(sequence_length, 8192); // LLaMA-3: max_seq_length = 8192 + buffer = std::vector(num_sequences * 
sequence_length); + } + std::visit( + [&](auto &vec) { + ifs.read(reinterpret_cast(vec.data()), data_size_in_bytes); + for (size_t i = 0; i < vec.size(); ++i) { dst[i] = static_cast(vec[i]); } + }, + buffer); + return text_file; } -TinyShakespeareDataset::TinyShakespeareDataset(const std::string &filepath, size_t sequence_length) { +TinyShakespeareDataset::TinyShakespeareDataset(const std::string &filepath, size_t sequence_length) + : text_file_(ReadTinyShakespeareFile(filepath, sequence_length)), sequence_length_(sequence_length), + sequence_size_in_bytes_(sequence_length * sizeof(int64_t)), num_samples_(text_file_.dims[0] - 1) { // =================================== 作业 =================================== // TODO:初始化数据集实例 // HINT: 调用ReadTinyShakespeareFile加载数据文件 // =================================== 作业 =================================== + CHECK_EQ(text_file_.dims[1], sequence_length_); + CHECK_EQ(static_cast(text_file_.tensor.Dtype()), static_cast(DataType::kINT64)); } ``` @@ -277,6 +656,41 @@ Tokenizer::Tokenizer(const std::string &filepath) { | magic(4B) | version(4B) | vocab_size(4B) | reserved(1012B) | token词表数据 | ---------------------------------------------------------------------------------- ===================================== 作业 ===================================== */ + if (!std::filesystem::exists(filepath)) { + LOG(FATAL) << "File not found: " << filepath; + } + + std::ifstream ifs(filepath, std::ios::binary); + const auto header = ReadSeveralBytesFromIfstream(1024, &ifs); + + magic_number_ = BytesToType(header, 0); + const uint32_t version_num = BytesToType(header, 4); + vocab_size_ = BytesToType(header, 8); + if (kEotMap.find(magic_number_) == kEotMap.end()) { + LOG(FATAL) << "Unsupported tokenizer magic: " << magic_number_; + } + + Version version = static_cast(version_num); + if (version == Version::kV1) { + eot_token_ = kEotMap.at(magic_number_); + } else if (version == Version::kV2) { + const uint32_t eot_token_2 = BytesToType(header, 12); + eot_token_ = eot_token_2; + } else { + LOG(FATAL) << "Unsupported tokenizer version: " << version_num; + return; + } + + token_table_.resize(vocab_size_); + for (uint32_t i = 0; i < vocab_size_; ++i) { + uint8_t length; + ifs.read(reinterpret_cast(&length), sizeof(length)); + + std::vector buffer(length); + ifs.read(buffer.data(), length); + + token_table_[i] = std::string(buffer.data(), length); + } } ``` @@ -286,19 +700,54 @@ std::string Tokenizer::Decode(uint32_t token_id) const { TODO:实现token_id到文本的转换 功能描述:根据token_id返回对应的文本片段 ===================================== 作业 ===================================== */ + if (token_id >= vocab_size_) { + return "[INVALID_TOKEN]"; + } + return token_table_[token_id]; } ``` ```c++ void Tokenizer::GenerateText(infini_train::nn::Module &model, uint32_t batch_size, uint32_t sequence_length, uint32_t text_length, Device device) const { - /* ...原代码... 
*/ + std::vector dims; + dims.assign({batch_size, sequence_length}); + // x_tensor (FLAGS_batch_size, FLAGS_sequence_length) eq:(4, 64) + infini_train::Tensor x_tensor = infini_train::Tensor(dims, DataType::kINT64); + int64_t *x_buff = static_cast(x_tensor.DataPtr()); + for (int i = 0; i < batch_size * sequence_length; ++i) { x_buff[i] = eot_token_; } + + // Give some contexts: "The meaning of life is " + auto prompt = kPromptMap.at(magic_number_); + auto prompt_len = prompt.size(); + for (int i = 0; i < prompt_len; ++i) { x_buff[i] = prompt[i]; } + std::cout << "The meaning of life is"; + + auto x = std::make_shared(x_tensor.To(device)); + uint64_t kRngState = kRngState; LOG(INFO) << "start generate text:"; + + infini_train::Device cpu_device = infini_train::Device(infini_train::DeviceType::kCPU, 0); for (int t = prompt_len; t < text_length; t++) { /* ===================================== 作业 ===================================== TODO:实现单步文本生成逻辑 HINT:调用model.Forward推理获取logits,根据推理结果进行随机采样,调用Decode获取文本结果 ===================================== 作业 ===================================== */ + x = std::make_shared(x->To(device)); // CPU->calc device + // TODO(jym): use no_grad forward later + auto logits = model.Forward({x})[0]; + auto logits_orignal = nn::function::Softmax(logits, -1); + auto logits_cpu = logits_orignal->To(cpu_device); + auto data = logits_cpu.DataPtr(); + auto vocab_size = logits->Dims()[2]; + float *probs = static_cast(data) + (t - 1) * vocab_size; + float coin = RandomF32(kRngState); + int next_token = SampleMult(probs, vocab_size, coin); + + x = std::make_shared(x->To(cpu_device)); // calc device->CPU + auto data_temp = static_cast(x->DataPtr()); + data_temp[t] = next_token; + std::cout << Decode(next_token); } std::cout << std::endl; } @@ -306,7 +755,42 @@ void Tokenizer::GenerateText(infini_train::nn::Module &model, uint32_t batch_siz #### 解决思路 - +模仿InfiniTrian中的实现 #### 遇到问题 +每当我训练无伦之后就会遇到一个奇怪的问题 + +```bash +Running main() from /home/eq/TinyInfiniTrain/third_party/googletest/googletest/src/gtest_main.cc +[==========] Running 1 test from 1 test suite. +[----------] Global test environment set-up. +[----------] 1 test from GPT2TrainingTest +[ RUN ] GPT2TrainingTest.LogitsConsistency +WARNING: Logging before InitGoogleLogging() is written to STDERR +E20250807 15:20:50.757825 138590656237568 net.cc:268] magic: 20240326 version: 3 block_size: 1024 vocab_size: 50257 n_layer: 12 n_head: 12 n_embd: 768 padded_vocab_size: 50304 +I20250807 15:20:51.473057 138590656237568 test_gpt2.cc:123] Initialize() finished! +I20250807 15:20:51.473109 138590656237568 test_gpt2.cc:207] epoch: 0 +I20250807 15:20:51.999021 138590656237568 test_gpt2.cc:207] epoch: 1 +I20250807 15:20:52.459228 138590656237568 test_gpt2.cc:207] epoch: 2 +I20250807 15:20:52.898673 138590656237568 test_gpt2.cc:207] epoch: 3 +I20250807 15:20:53.339709 138590656237568 test_gpt2.cc:207] epoch: 4 +I20250807 15:20:53.797280 138590656237568 test_gpt2.cc:207] epoch: 5 +I20250807 15:20:54.212782 138590656237568 test_gpt2.cc:207] epoch: 6 +unknown file: Failure +C++ exception with description "parallel_for failed: cudaErrorInvalidDevice: invalid device ordinal" thrown in the test body. + +[ FAILED ] GPT2TrainingTest.LogitsConsistency (46613 ms) +[----------] 1 test from GPT2TrainingTest (46613 ms total) + +[----------] Global test environment tear-down +[==========] 1 test from 1 test suite ran. (46613 ms total) +[ PASSED ] 0 tests. 
+[ FAILED ] 1 test, listed below: +[ FAILED ] GPT2TrainingTest.LogitsConsistency + + 1 FAILED TEST +``` + +内存和GPU的记录图如下: +![记录](./record.png) \ No newline at end of file diff --git a/docs/image.png b/docs/image.png new file mode 100644 index 0000000..4d38e7e Binary files /dev/null and b/docs/image.png differ diff --git a/docs/record.png b/docs/record.png new file mode 100644 index 0000000..2367db6 Binary files /dev/null and b/docs/record.png differ diff --git a/example/common/tiny_shakespeare_dataset.cc b/example/common/tiny_shakespeare_dataset.cc index 3bc5f1b..1f545f6 100644 --- a/example/common/tiny_shakespeare_dataset.cc +++ b/example/common/tiny_shakespeare_dataset.cc @@ -61,14 +61,55 @@ TinyShakespeareFile ReadTinyShakespeareFile(const std::string &path, size_t sequ | magic(4B) | version(4B) | num_toks(4B) | reserved(1012B) | token数据 | ---------------------------------------------------------------------------------- =================================== 作业 =================================== */ + if (!std::filesystem::exists(path)) { + LOG(FATAL) << "File not found: " << path; + } + + TinyShakespeareFile text_file; + std::ifstream ifs(path, std::ios::binary); + const auto header = ReadSeveralBytesFromIfstream(1024, &ifs); + const int magic = BytesToType(header, 0); + const int version = BytesToType(header, 4); + const int num_tokens = BytesToType(header, 8); + text_file.type = kTypeMap.at(magic); + + const int num_sequences = num_tokens / sequence_length; + text_file.dims.assign({num_sequences, static_cast(sequence_length)}); + + const int data_size_in_bytes + = kTypeToSize.at(text_file.type) + * std::accumulate(text_file.dims.begin(), text_file.dims.end(), 1, std::multiplies()); + // shape: (num_seq, seq_len), dtype: int64 + text_file.tensor = infini_train::Tensor(text_file.dims, DataType::kINT64); + int64_t *dst = static_cast(text_file.tensor.DataPtr()); + + std::variant, std::vector> buffer; + if (text_file.type == TinyShakespeareType::kUINT16) { + CHECK_LE(sequence_length, 1024); // GPT-2: max_seq_length = 1024 + buffer = std::vector(num_sequences * sequence_length); + } else if (text_file.type == TinyShakespeareType::kUINT32) { + CHECK_LE(sequence_length, 8192); // LLaMA-3: max_seq_length = 8192 + buffer = std::vector(num_sequences * sequence_length); + } + std::visit( + [&](auto &vec) { + ifs.read(reinterpret_cast(vec.data()), data_size_in_bytes); + for (size_t i = 0; i < vec.size(); ++i) { dst[i] = static_cast(vec[i]); } + }, + buffer); + return text_file; } } // namespace -TinyShakespeareDataset::TinyShakespeareDataset(const std::string &filepath, size_t sequence_length) { +TinyShakespeareDataset::TinyShakespeareDataset(const std::string &filepath, size_t sequence_length) + : text_file_(ReadTinyShakespeareFile(filepath, sequence_length)), sequence_length_(sequence_length), + sequence_size_in_bytes_(sequence_length * sizeof(int64_t)), num_samples_(text_file_.dims[0] - 1) { // =================================== 作业 =================================== // TODO:初始化数据集实例 // HINT: 调用ReadTinyShakespeareFile加载数据文件 // =================================== 作业 =================================== + CHECK_EQ(text_file_.dims[1], sequence_length_); + CHECK_EQ(static_cast(text_file_.tensor.Dtype()), static_cast(DataType::kINT64)); } std::pair, std::shared_ptr> diff --git a/example/common/tokenizer.cc b/example/common/tokenizer.cc index 23b9537..8684880 100644 --- a/example/common/tokenizer.cc +++ b/example/common/tokenizer.cc @@ -78,6 +78,41 @@ Tokenizer::Tokenizer(const std::string &filepath) { 
| magic(4B) | version(4B) | vocab_size(4B) | reserved(1012B) | token词表数据 | ---------------------------------------------------------------------------------- ===================================== 作业 ===================================== */ + if (!std::filesystem::exists(filepath)) { + LOG(FATAL) << "File not found: " << filepath; + } + + std::ifstream ifs(filepath, std::ios::binary); + const auto header = ReadSeveralBytesFromIfstream(1024, &ifs); + + magic_number_ = BytesToType(header, 0); + const uint32_t version_num = BytesToType(header, 4); + vocab_size_ = BytesToType(header, 8); + if (kEotMap.find(magic_number_) == kEotMap.end()) { + LOG(FATAL) << "Unsupported tokenizer magic: " << magic_number_; + } + + Version version = static_cast(version_num); + if (version == Version::kV1) { + eot_token_ = kEotMap.at(magic_number_); + } else if (version == Version::kV2) { + const uint32_t eot_token_2 = BytesToType(header, 12); + eot_token_ = eot_token_2; + } else { + LOG(FATAL) << "Unsupported tokenizer version: " << version_num; + return; + } + + token_table_.resize(vocab_size_); + for (uint32_t i = 0; i < vocab_size_; ++i) { + uint8_t length; + ifs.read(reinterpret_cast(&length), sizeof(length)); + + std::vector buffer(length); + ifs.read(buffer.data(), length); + + token_table_[i] = std::string(buffer.data(), length); + } } std::string Tokenizer::Decode(uint32_t token_id) const { @@ -85,7 +120,10 @@ std::string Tokenizer::Decode(uint32_t token_id) const { TODO:实现token_id到文本的转换 功能描述:根据token_id返回对应的文本片段 ===================================== 作业 ===================================== */ - return ""; + if (token_id >= vocab_size_) { + return "[INVALID_TOKEN]"; + } + return token_table_[token_id]; } void Tokenizer::GenerateText(infini_train::nn::Module &model, uint32_t batch_size, uint32_t sequence_length, @@ -106,11 +144,28 @@ void Tokenizer::GenerateText(infini_train::nn::Module &model, uint32_t batch_siz auto x = std::make_shared(x_tensor.To(device)); uint64_t kRngState = kRngState; LOG(INFO) << "start generate text:"; + + infini_train::Device cpu_device = infini_train::Device(infini_train::DeviceType::kCPU, 0); for (int t = prompt_len; t < text_length; t++) { /* ===================================== 作业 ===================================== TODO:实现单步文本生成逻辑 HINT:调用model.Forward推理获取logits,根据推理结果进行随机采样,调用Decode获取文本结果 ===================================== 作业 ===================================== */ + x = std::make_shared(x->To(device)); // CPU->calc device + // TODO(jym): use no_grad forward later + auto logits = model.Forward({x})[0]; + auto logits_orignal = nn::function::Softmax(logits, -1); + auto logits_cpu = logits_orignal->To(cpu_device); + auto data = logits_cpu.DataPtr(); + auto vocab_size = logits->Dims()[2]; + float *probs = static_cast(data) + (t - 1) * vocab_size; + float coin = RandomF32(kRngState); + int next_token = SampleMult(probs, vocab_size, coin); + + x = std::make_shared(x->To(cpu_device)); // calc device->CPU + auto data_temp = static_cast(x->DataPtr()); + data_temp[t] = next_token; + std::cout << Decode(next_token); } std::cout << std::endl; } diff --git a/infini_train/include/common/common.h b/infini_train/include/common/common.h new file mode 100644 index 0000000..5b16dea --- /dev/null +++ b/infini_train/include/common/common.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +#include "glog/logging.h" + +#include "infini_train/include/datatype.h" +#include "infini_train/include/device.h" +#include "infini_train/include/tensor.h" + +#define CEIL_DIV(x, y) (((x) + (y)-1) / (y)) 
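+// LOG_LOC appends the source file and line to a glog message; LOG_UNSUPPORTED_DTYPE fatally logs
+// an unsupported DataType together with the calling context.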
+#define LOG_LOC(LEVEL, MSG) LOG(LEVEL) << MSG << " at " << __FILE__ << ":" << __LINE__ +#define LOG_UNSUPPORTED_DTYPE(DTYPE, CONTEXT_IDENTIFIER) \ + LOG_LOC(FATAL, WRAP(CONTEXT_IDENTIFIER << ": Unsupported data type: " \ + + kDataTypeToDesc.at(static_cast(dtype)))) diff --git a/infini_train/include/common/cuda/common_cuda.cuh b/infini_train/include/common/cuda/common_cuda.cuh new file mode 100644 index 0000000..812fcb1 --- /dev/null +++ b/infini_train/include/common/cuda/common_cuda.cuh @@ -0,0 +1,285 @@ +#pragma once + +#include "cuda.h" +#include "cuda_runtime.h" +#include +#include + +#include "../common.h" +#include "infini_train/include/dispatcher.h" +#ifdef USE_NCCL +#include "nccl.h" +#endif + +namespace infini_train::common::cuda { + +// Common CUDA Macros +#define CUDA_CHECK(call) \ + do { \ + cudaError_t status = call; \ + if (status != cudaSuccess) { \ + LOG(FATAL) << "CUDA Error: " << cudaGetErrorString(status) << " at " << __FILE__ << ":" << __LINE__; \ + } \ + } while (0) + +#define CUBLAS_CHECK(call) \ + do { \ + cublasStatus_t status = call; \ + if (status != CUBLAS_STATUS_SUCCESS) { \ + LOG(FATAL) << "CUBLAS Error: " << cublasGetStatusString(status) << " at " << __FILE__ << ":" << __LINE__; \ + } \ + } while (0) + +#define CUDA_DRIVER_CHECK(call) \ + do { \ + CUresult status = call; \ + if (status != CUresult::CUDA_SUCCESS) { \ + const char *err_str = nullptr; \ + cuGetErrorString(status, &err_str); \ + LOG(FATAL) << "CUDA Driver API error: " << #call << " failed with error (" << status \ + << "): " << (err_str ? err_str : "Unknown error"); \ + } \ + } while (0) + +#ifdef USE_NCCL +#define NCCL_CHECK(expr) \ + do { \ + ncclResult_t _status = (expr); \ + if (_status != ncclSuccess) { \ + LOG(FATAL) << "NCCL error: " << ncclGetErrorString(_status) << " at " << __FILE__ << ":" << __LINE__ \ + << " (" << #expr << ")"; \ + } \ + } while (0) +#endif + +/** + * Converts a value between arbitrary types with specialized handling for + * CUDA floating-point precisions. 
For primitive types, this offers perfect + * forwarding which preserves value categories (lvalues/rvalues) + * + * @tparam DST Destination type (deduced) + * @tparam SRC Source type (deduced) + * @param x Input value (preserves const/volatile and value category) + * @return Value converted to DST type + * + * Example: + * half h = Cast(3.14f); // float -> half (CUDA intrinsic) + * float f = Cast(h); // half -> float (CUDA intrinsic) + * int i = Cast(2.718); // double -> int (standard cast) + */ +// TODO(lzm): add support for half and nv_bfloat16 conversions with integral types +template __host__ __device__ DST Cast(SRC &&x) { + static_assert(!std::is_reference_v, "Cast cannot return reference types"); + + using SRC_base = std::remove_cv_t>; + using DST_base = std::remove_cv_t>; + + // nv_bfloat16 conversions + if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { + return __bfloat162float(x); + } else if constexpr (std::is_same_v) { + return static_cast(__bfloat162float(x)); + } else if constexpr (std::is_same_v) { + return __half(x); + } + } + // half conversions + else if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { + return __half2float(x); + } else if constexpr (std::is_same_v) { + return static_cast(__half2float(x)); + } else if constexpr (std::is_same_v) { + return __nv_bfloat16(x); + } + } + // float conversions to reduced precision + else if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { + return __float2bfloat16(x); + } else if constexpr (std::is_same_v) { + return __float2half(x); + } + } + // double conversions to reduced precision + else if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { + return __double2bfloat16(x); + } else if constexpr (std::is_same_v) { + return __double2half(x); + } + } + // Fallback for all other conversions + return (DST)(std::forward(x)); +} + +template __device__ __forceinline__ T Neg(const T &x) { + if constexpr (std::is_same_v || std::is_same_v) { + return __hneg(x); + } else { + return -x; + } +} + +template __device__ __forceinline__ T Reciprocal(const T &x) { + if constexpr (std::is_same_v) { + return __hdiv(__float2half(1.0f), x); + } else if constexpr (std::is_same_v) { + return __hdiv(__float2bfloat16(1.0f), x); + } else { + return T(1) / x; + } +} + +template __device__ __forceinline__ T Sin(const T &x) { + if constexpr (std::is_same_v) { + return __float2half(__sinf(__half2float(x))); + } else if constexpr (std::is_same_v) { + return __float2bfloat16(__sinf(__bfloat162float(x))); + } else if constexpr (std::is_same_v) { + return __sinf(x); + } else { + return std::sin(x); + } +} + +template __device__ __forceinline__ T Cos(const T &x) { + if constexpr (std::is_same_v) { + return __float2half(__cosf(__half2float(x))); + } else if constexpr (std::is_same_v) { + return __float2bfloat16(__cosf(__bfloat162float(x))); + } else if constexpr (std::is_same_v) { + return __cosf(x); + } else { + return std::cos(x); + } +} + +template __device__ __forceinline__ T Tanh(const T &x) { + if constexpr (std::is_same_v || std::is_same_v) { + return htanh(x); + } else if constexpr (std::is_same_v) { + return tanhf(x); + } else { + return std::tanh(x); + } +} + +template __device__ __forceinline__ T Pow(const T &x, const T &exponent) { + if constexpr (std::is_same_v) { + float x_ = __bfloat162float(x); + float exponent_ = __bfloat162float(exponent); + float ans_f = __powf(x_, exponent_); + return __float2bfloat16(__isnan(ans_f) ? 
std::pow(x_, exponent_) : ans_f); + } else if constexpr (std::is_same_v) { + float x_ = __half2float(x); + float exponent_ = __half2float(exponent); + float ans_f = __powf(x_, exponent_); + return __float2half(__isnan(ans_f) ? std::pow(x_, exponent_) : ans_f); + } else if constexpr (std::is_same_v) { + return powf(x, exponent); + } else { + return std::pow(x, exponent); + } +} + +template __device__ __forceinline__ T Rsqrt(const T &x) { + if constexpr (std::is_same_v) { + return __float2half(rsqrtf(__half2float(x))); + } else if constexpr (std::is_same_v) { + return __float2bfloat16(rsqrtf(__bfloat162float(x))); + } else if constexpr (std::is_same_v) { + return rsqrtf(x); + } else { + return T(1) / std::sqrt(T(x)); + } +} + +template __device__ __forceinline__ T Log(const T &x) { + if constexpr (std::is_same_v) { + return __float2bfloat16(__logf(__bfloat162float(x))); + } else if constexpr (std::is_same_v) { + return __float2half(__logf(__half2float(x))); + } else if constexpr (std::is_same_v) { + return __logf(x); + } else { + return std::log(x); + } +} + +template __device__ __forceinline__ T Add(const T &a, const T &b) { + if constexpr (std::is_same_v || std::is_same_v) { + return __hadd(a, b); + } else { + return a + b; + } +} + +template __device__ __forceinline__ T Sub(const T &a, const T &b) { + if constexpr (std::is_same_v || std::is_same_v) { + return __hsub(a, b); + } else { + return a - b; + } +} + +template __device__ __forceinline__ T Mul(const T &a, const T &b) { + if constexpr (std::is_same_v || std::is_same_v) { + return __hmul(a, b); + } else { + return a * b; + } +} + +template __device__ __forceinline__ T Div(const T &a, const T &b) { + if constexpr (std::is_same_v || std::is_same_v) { + return __hdiv(a, b); + } else { + return a / b; + } +} + +template __device__ __forceinline__ T Sigmoid(const T &x) { + if constexpr (std::is_same_v) { + return 1.0f / (1.0f + expf(-x)); + } else if constexpr (std::is_same_v || std::is_same_v) { + return __hdiv(T(1), T(1) + hexp(-x)); + } else { + return T(1) / (T(1) + std::exp(-x)); + } +} + +template __device__ __forceinline__ T Max(const T &a, const T &b) { + if constexpr (std::is_same_v || std::is_same_v) { + return __hle(a, b) ? b : a; + } else if constexpr (std::is_same_v) { + return fmaxf(a, b); + } else { + return std::max(a, b); + } +} + +template __device__ __forceinline__ T Min(const T &a, const T &b) { + if constexpr (std::is_same_v || std::is_same_v) { + return __hle(a, b) ? 
a : b; + } else if constexpr (std::is_same_v) { + return fminf(a, b); + } else { + return std::min(a, b); + } +} + +template __device__ __forceinline__ T Fma(const T &x, const T &y, const T &z) { + if constexpr (std::is_same_v) { + return __hfma(x, y, z); + } else if constexpr (std::is_same_v) { + return __float2bfloat16(__fmaf_rn(__bfloat162float(x), __bfloat162float(y), __bfloat162float(z))); + } else if constexpr (std::is_same_v) { + return __fmaf_rn(x, y, z); + } else { + return std::fma(x, y, z); + } +} + +} // namespace infini_train::common::cuda diff --git a/infini_train/include/datatype.h b/infini_train/include/datatype.h new file mode 100644 index 0000000..70d712b --- /dev/null +++ b/infini_train/include/datatype.h @@ -0,0 +1,99 @@ +#pragma once + +#include +#include +#include +#include + +#ifdef USE_CUDA +#include +#include +#endif + +namespace infini_train { +enum class DataType : int8_t { + kUINT8, + kINT8, + kUINT16, + kINT16, + kUINT32, + kINT32, + kUINT64, + kINT64, + kBFLOAT16, + kFLOAT16, + kFLOAT32, + kFLOAT64, +}; + +inline const std::unordered_map kDataTypeToSize = { + {DataType::kUINT8, 1}, {DataType::kINT8, 1}, {DataType::kUINT16, 2}, {DataType::kINT16, 2}, + {DataType::kUINT32, 4}, {DataType::kINT32, 4}, {DataType::kUINT64, 8}, {DataType::kINT64, 8}, + {DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4}, {DataType::kFLOAT64, 8}, +}; + +inline const std::unordered_map kDataTypeToDesc = { + {DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"}, {DataType::kUINT16, "uint16"}, + {DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"}, {DataType::kINT32, "int32"}, + {DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"}, {DataType::kBFLOAT16, "bf16"}, + {DataType::kFLOAT16, "fp16"}, {DataType::kFLOAT32, "fp32"}, {DataType::kFLOAT64, "fp64"}, +}; + +/** + * Compile-time type mapping from DataType enum to concrete C++ types. + * + * - Primary template: Declared but undefined to enforce specialization + * - Specializations: Explicit mappings (DataType::kFLOAT32 → float, etc) + * - TypeMap_t alias: Direct access to mapped type (TypeMap_t → int32_t) + * + * Enables type-safe generic code where operations dispatch based on DataType tokens, + * with zero runtime overhead. Extend by adding new specializations. 
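+ * Example: TypeMap_t<DataType::kBFLOAT16> resolves to nv_bfloat16 when USE_CUDA is defined and to uint16_t otherwise.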
+ */ +template struct TypeMap; +template using TypeMap_t = typename TypeMap::type; + +template <> struct TypeMap { + using type = uint8_t; +}; +template <> struct TypeMap { + using type = int8_t; +}; +template <> struct TypeMap { + using type = uint16_t; +}; +template <> struct TypeMap { + using type = int16_t; +}; +template <> struct TypeMap { + using type = uint32_t; +}; +template <> struct TypeMap { + using type = int32_t; +}; +template <> struct TypeMap { + using type = uint64_t; +}; +template <> struct TypeMap { + using type = int64_t; +}; +template <> struct TypeMap { + using type = float; +}; +template <> struct TypeMap { + using type = double; +}; +template <> struct TypeMap { +#ifdef USE_CUDA + using type = nv_bfloat16; +#else + using type = uint16_t; +#endif +}; +template <> struct TypeMap { +#ifdef USE_CUDA + using type = half; +#else + using type = uint16_t; +#endif +}; +} // namespace infini_train diff --git a/infini_train/include/device.h b/infini_train/include/device.h index ebc49e3..0272ae6 100644 --- a/infini_train/include/device.h +++ b/infini_train/include/device.h @@ -1,9 +1,19 @@ #pragma once #include +#include #include "glog/logging.h" +#ifdef USE_CUDA +#include "cublas_v2.h" +#include "cuda.h" +#include "cuda_runtime_api.h" +#endif +#ifdef USE_NCCL +#include "nccl.h" +#endif + namespace infini_train { enum class DeviceType : int8_t { kCPU = 0, @@ -26,13 +36,71 @@ class Device { bool IsCPU() const; bool IsCUDA() const; + virtual void SetDevice() const {} + virtual void Synchronize() const {} + std::string ToString() const; friend std::ostream &operator<<(std::ostream &os, const Device &device); -private: +protected: DeviceType type_; int8_t index_; }; +class CpuDevice : public Device { +private: + CpuDevice(); + + friend class DeviceManager; +}; + +#ifdef USE_CUDA +class CudaDevice : public Device { +public: + ~CudaDevice(); + + void SetDevice() const override; + void Synchronize() const override; + + cudaStream_t Stream() const; + + cublasHandle_t CublasHandle() const; +#ifdef USE_NCCL + ncclComm_t NcclComm() const; +#endif + +private: + CudaDevice(int8_t index); + + cudaStream_t stream_ = nullptr; + + cublasHandle_t cublas_handle_ = nullptr; +#ifdef USE_NCCL + ncclComm_t nccl_comm_ = nullptr; +#endif + + friend class DeviceManager; +}; +#endif + +class DeviceManager { +public: + static const DeviceManager *Instance(); + + const Device *GetDevice(DeviceType type, int8_t index = 0) const; + + const Device *GetDefaultDevice() const; + + std::vector GetAllAvailableDevices(DeviceType device_type) const; + +private: + DeviceManager(); + +#ifdef USE_NCCL + void InitNcclCommunicators(); +#endif + + std::unordered_map>> devices_map_; +}; } // namespace infini_train diff --git a/infini_train/include/dispatcher.h b/infini_train/include/dispatcher.h index 5b91d85..77eea14 100644 --- a/infini_train/include/dispatcher.h +++ b/infini_train/include/dispatcher.h @@ -7,21 +7,400 @@ #include "glog/logging.h" +#include "infini_train/include/common/common.h" #include "infini_train/include/device.h" +/** + * General Utility Macros + */ +#define EXPAND(X) X +// This macro lets you pass an arbitrary expression that may contain internal +// commas to another macro without having the commas causing the expression +// to be interpreted as being multiple arguments +// Basically an alternative for __VA_OPTS__ before C++20 +// ref: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Dispatch_v2.h +#define WRAP(...) 
__VA_ARGS__ +#define CAT(a, b) CAT_(a, b) +#define CAT_(a, b) a##b + +/** + * Data Type Macros + * Defines common categories of data types for dispatching + */ +#define INFINI_FLOATING_TYPES DataType::kFLOAT32, DataType::kFLOAT64 +#define INFINI_REDUCED_FLOATING_TYPES DataType::kFLOAT16, DataType::kBFLOAT16 +#define INFINI_ALL_FLOATING_TYPES EXPAND(INFINI_FLOATING_TYPES), EXPAND(INFINI_REDUCED_FLOATING_TYPES) +#define INFINI_SIGNED_INTEGRAL_TYPES DataType::kINT8, DataType::kINT16, DataType::kINT32, DataType::kINT64 +#define INFINI_UNSIGNED_INTEGRAL_TYPES DataType::kUINT8, DataType::kUINT16, DataType::kUINT32, DataType::kUINT64 +#define INFINI_ALL_INTEGRAL_TYPES EXPAND(INFINI_SIGNED_INTEGRAL_TYPES), EXPAND(INFINI_UNSIGNED_INTEGRAL_TYPES) +#define INFINI_ALL_TYPES EXPAND(INFINI_ALL_FLOATING_TYPES), EXPAND(INFINI_ALL_INTEGRAL_TYPES) +#define INFINI_8_BIT_TYPES DataType::kINT8, DataType::kUINT8 +#define INFINI_16_BIT_TYPES DataType::kINT16, DataType::kUINT16, DataType::kFLOAT16, DataType::kBFLOAT16 +#define INFINI_32_BIT_TYPES DataType::kINT32, DataType::kUINT32, DataType::kFLOAT32 +#define INFINI_64_BIT_TYPES DataType::kINT64, DataType::kUINT64, DataType::kFLOAT64 + +/** + * Dispatch Macros + */ +#define DISPATCH_WITH_DEFAULT(DTYPE_EXPR, BODY, DEFAULT_BODY, ...) \ + switch (DTYPE_EXPR) { \ + CAT(DISPATCH_CASE_, PP_NARG(__VA_ARGS__))(__VA_ARGS__, WRAP(BODY)) default : { WRAP(DEFAULT_BODY); } \ + } + +// dispatch with switch and arbitrary number of cases +#define DISPATCH(DTYPE_EXPR, BODY, ...) \ + DISPATCH_WITH_DEFAULT( \ + DTYPE_EXPR, WRAP(BODY), \ + EXPAND(LOG(FATAL) << "Unsupported data type at " << __FILE__ << ":" << __LINE__; return nullptr;), \ + __VA_ARGS__) + +// dispatch a single case +#define DISPATCH_CASE(BODY, ...) CAT(DISPATCH_CASE_, PP_NARG(__VA_ARGS__))(__VA_ARGS__, WRAP(BODY)) + +// Helper macros to count the number of arguments +#define PP_NARG(...) PP_NARG_(__VA_ARGS__, PP_RSEQ_N()) +#define PP_NARG_(...) PP_ARG_N(__VA_ARGS__) +#define PP_ARG_N(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, \ + _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, \ + _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, \ + _63, N, ...) 
\ + N +#define PP_RSEQ_N() \ + 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, \ + 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, \ + 7, 6, 5, 4, 3, 2, 1, 0 + +// Macros to generate case labels +// Should have up to number of DataType cases (currently 12) +#define DISPATCH_CASE_1(T1, BODY) \ + case T1: { \ + BODY break; \ + } + +#define DISPATCH_CASE_2(T1, T2, BODY) \ + case T1: \ + case T2: { \ + BODY break; \ + } + +#define DISPATCH_CASE_3(T1, T2, T3, BODY) \ + case T1: \ + case T2: \ + case T3: { \ + BODY break; \ + } + +#define DISPATCH_CASE_4(T1, T2, T3, T4, BODY) \ + case T1: \ + case T2: \ + case T3: \ + case T4: { \ + BODY break; \ + } + +#define DISPATCH_CASE_5(T1, T2, T3, T4, T5, BODY) \ + case T1: \ + case T2: \ + case T3: \ + case T4: \ + case T5: { \ + BODY break; \ + } + +#define DISPATCH_CASE_6(T1, T2, T3, T4, T5, T6, BODY) \ + case T1: \ + case T2: \ + case T3: \ + case T4: \ + case T5: \ + case T6: { \ + BODY break; \ + } + +#define DISPATCH_CASE_7(T1, T2, T3, T4, T5, T6, T7, BODY) \ + case T1: \ + case T2: \ + case T3: \ + case T4: \ + case T5: \ + case T6: \ + case T7: { \ + BODY break; \ + } + +#define DISPATCH_CASE_8(T1, T2, T3, T4, T5, T6, T7, T8, BODY) \ + case T1: \ + case T2: \ + case T3: \ + case T4: \ + case T5: \ + case T6: \ + case T7: \ + case T8: { \ + BODY break; \ + } + +#define DISPATCH_CASE_9(T1, T2, T3, T4, T5, T6, T7, T8, T9, BODY) \ + case T1: \ + case T2: \ + case T3: \ + case T4: \ + case T5: \ + case T6: \ + case T7: \ + case T8: \ + case T9: { \ + BODY break; \ + } + +#define DISPATCH_CASE_10(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, BODY) \ + case T1: \ + case T2: \ + case T3: \ + case T4: \ + case T5: \ + case T6: \ + case T7: \ + case T8: \ + case T9: \ + case T10: { \ + BODY break; \ + } + +#define DISPATCH_CASE_11(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, BODY) \ + case T1: \ + case T2: \ + case T3: \ + case T4: \ + case T5: \ + case T6: \ + case T7: \ + case T8: \ + case T9: \ + case T10: \ + case T11: { \ + BODY break; \ + } + +#define DISPATCH_CASE_12(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, BODY) \ + case T1: \ + case T2: \ + case T3: \ + case T4: \ + case T5: \ + case T6: \ + case T7: \ + case T8: \ + case T9: \ + case T10: \ + case T11: \ + case T12: { \ + BODY break; \ + } + namespace infini_train { + +template struct DataTypeList {}; + +template struct IsDataTypeInList; + +template +struct IsDataTypeInList> : std::disjunction...> {}; + +template +inline constexpr bool IsDataTypeInList_v = IsDataTypeInList::value; + +// function to check if a type is in a list of types +template inline constexpr bool IsTypeInList = (std::is_same_v || ...); + +/** + * @brief Dispatches a functor call based on runtime DataType, restricted to specified allowed types. + * + * This function: + * 1. Maps runtime DataType to compile-time C++ types using TypeMap_t + * 2. Only processes types specified in AllowedDTypes template parameter + * 3. 
Calls functor with resolved type and forwarded arguments + * + * @tparam AllowedDTypes List of DataType enums to support + * @param dtype Runtime data type to dispatch + * @param func Templated functor to call (must accept operator()) + * @param context_identifier Optional string for context in error messages + * @param args Arguments to be forwarded to the functor + * + * Behavior: + * - For allowed types: Instantiates functor with mapped C++ type + * - For disallowed and unknown types: Logs error and returns + * + * @see TypeMap for DataType to C++ type mapping + */ +template +auto DispatchFunc(DataType dtype, Functor &&func, std::string_view context_identifier = "", Args &&...args) { + switch (dtype) { + +#define CASE_FOR_TYPE(DType) \ + case DType: { \ + if constexpr (IsTypeInList, TypeMap_t...>) { \ + return std::forward(func).template operator()>(std::forward(args)...); \ + } else { \ + break; \ + } \ + } + + CASE_FOR_TYPE(DataType::kUINT8) + CASE_FOR_TYPE(DataType::kINT8) + CASE_FOR_TYPE(DataType::kUINT16) + CASE_FOR_TYPE(DataType::kINT16) + CASE_FOR_TYPE(DataType::kUINT32) + CASE_FOR_TYPE(DataType::kINT32) + CASE_FOR_TYPE(DataType::kUINT64) + CASE_FOR_TYPE(DataType::kINT64) + CASE_FOR_TYPE(DataType::kFLOAT32) + CASE_FOR_TYPE(DataType::kFLOAT64) +#ifdef USE_CUDA + CASE_FOR_TYPE(DataType::kBFLOAT16) + CASE_FOR_TYPE(DataType::kFLOAT16) +#endif +#undef CASE_FOR_TYPE + } + LOG_UNSUPPORTED_DTYPE(dtype, context_identifier); + // prevent the compiler warning about control reaching the end of non-void function + std::abort(); +} + +namespace { +/** + * @brief Responsible for resolving a list of data types and invoking a functor with the corresponding C++ types. + * + * @tparam index Current index in the `dtypes` vector. + * @tparam AllowedListTuple Tuple of allowed `DataType` sets per dispatch level. + * @tparam ResolvedTypes Accumulated resolved C++ types. + */ +template struct DtypeDispatcher { + + /** + * @brief Dispatches based on runtime data types and invokes the functor with resolved C++ types. + * + * Recursively matches each `DataType` in `dtypes` against the corresponding allowed list in + * `AllowedListTuple`. For each match, maps the `DataType` to a C++ type using `TypeMap_t`. + * Once all types are resolved, invokes the functor. + * + * @param dtypes Vector of runtime data types to dispatch on. + * @param func Functor to invoke with resolved template types. + * @param context_identifier String used for logging or error context. + * @param args Additional arguments forwarded to the functor. + * @return Result of invoking the functor with resolved types and forwarded arguments. 
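+ *
+ * Example (illustrative sketch only; CastImpl, src_ptr, dst_ptr, n, src_dtype and
+ * dst_dtype are hypothetical names, not part of this header): with two allowed
+ * lists, both runtime dtypes are resolved and the templated lambda is invoked as
+ * func.operator()<SrcT, DstT>(), typically through the public DispatchFunc
+ * overload declared below:
+ * @code
+ * DispatchFunc<DataTypeList<DataType::kFLOAT32, DataType::kFLOAT64>,
+ *              DataTypeList<DataType::kFLOAT32, DataType::kFLOAT64>>(
+ *     {src_dtype, dst_dtype},
+ *     [=]<typename SrcT, typename DstT>() { CastImpl<SrcT, DstT>(src_ptr, dst_ptr, n); },
+ *     "Cast");
+ * @endcode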
+ */ + template + static auto call(const std::vector &dtypes, Functor &&func, std::string_view context_identifier, + Args &&...args) { + constexpr size_t num_lists = std::tuple_size_v; + + if constexpr (index == num_lists) { + // Base case: All types resolved, invoke the functor + return std::forward(func).template operator()(std::forward(args)...); + } else { + // Recursive case: Resolve the next type + using CurrentList = std::tuple_element_t; + DataType dtype = dtypes[index]; + + switch (dtype) { +#define CASE_FOR_TYPE(DType) \ + case DType: \ + if constexpr (IsDataTypeInList_v) { \ + using T = TypeMap_t; \ + return DtypeDispatcher::call( \ + dtypes, std::forward(func), context_identifier, std::forward(args)...); \ + } else { \ + break; \ + } + + CASE_FOR_TYPE(DataType::kUINT8) + CASE_FOR_TYPE(DataType::kINT8) + CASE_FOR_TYPE(DataType::kUINT16) + CASE_FOR_TYPE(DataType::kINT16) + CASE_FOR_TYPE(DataType::kUINT32) + CASE_FOR_TYPE(DataType::kINT32) + CASE_FOR_TYPE(DataType::kUINT64) + CASE_FOR_TYPE(DataType::kINT64) + CASE_FOR_TYPE(DataType::kFLOAT32) + CASE_FOR_TYPE(DataType::kFLOAT64) +#ifdef USE_CUDA + CASE_FOR_TYPE(DataType::kBFLOAT16) + CASE_FOR_TYPE(DataType::kFLOAT16) +#endif +#undef CASE_FOR_TYPE + } + LOG_UNSUPPORTED_DTYPE(dtype, context_identifier); + // prevent the compiler warning about control reaching the end of non-void function + std::abort(); + } + } +}; +} // namespace + +/** + * @brief Dispatches a functor based on a list of runtime data types. + * + * Given a vector of `DataType` values and corresponding allowed type lists, this function resolves + * each data type to its mapped C++ type using `TypeMap_t`, then invokes the provided functor with + * those types as template parameters. + * + * @tparam AllowedTypeLists Variadic list of allowed data type sets per dispatch level. + * @tparam Functor Callable object with a templated call operator. + * @tparam Args Additional arguments to forward to the functor. + * + * @param dtypes Vector of runtime data types to dispatch on. + * @param func Functor to invoke after resolving types. + * @param context_identifier Optional context string for error reporting/logging. + * @param args Additional arguments to pass to the functor. + * @return Result of invoking the functor with resolved template types and arguments. + * + * Example functor using a templated lambda: [=]() { ... } + */ +template +auto DispatchFunc(const std::vector &dtypes, Functor &&func, std::string_view context_identifier = "", + Args &&...args) { + constexpr size_t num_lists = sizeof...(AllowedTypeLists); + if (dtypes.size() != num_lists) { + LOG(FATAL) << std::format("DispatchFunc expects {} dtypes, but only got {} in {}", num_lists, dtypes.size(), + context_identifier); + std::abort(); + } + + using AllowedListTuple = std::tuple; + return DtypeDispatcher<0, AllowedListTuple>::call(dtypes, std::forward(func), context_identifier, + std::forward(args)...); +} + class KernelFunction { public: template explicit KernelFunction(FuncT &&func) : func_ptr_(reinterpret_cast(func)) {} template RetT Call(ArgsT... 
args) const { // =================================== 作业 =================================== - // TODO:实现通用kernel调用接口 - // 功能描述:将存储的函数指针转换为指定类型并调用 - // =================================== 作业 =================================== +#ifdef PROFILE_MODE + const auto &ctx = GetProfileContext(); + Profiler::Instance().StartRecord(ctx.name, ctx.device); +#endif using FuncT = RetT (*)(ArgsT...); - // TODO: 实现函数调用逻辑 + auto fn = reinterpret_cast(func_ptr_); + + if constexpr (std::is_void_v) { + fn(std::forward(args)...); + +#ifdef PROFILE_MODE + Profiler::Instance().EndRecord(ctx.name, ctx.device); +#endif + return; + } else { + RetT ret = fn(std::forward(args)...); + +#ifdef PROFILE_MODE + Profiler::Instance().EndRecord(ctx.name, ctx.device); +#endif + return ret; + } } private: @@ -48,6 +427,9 @@ class Dispatcher { // TODO:实现kernel注册机制 // 功能描述:将kernel函数与设备类型、名称绑定 // =================================== 作业 =================================== + CHECK(!key_to_kernel_map_.contains(key)) + << "Kernel already registered: " << key.second << " on device: " << static_cast(key.first); + key_to_kernel_map_.emplace(key, kernel); } private: @@ -56,7 +438,7 @@ class Dispatcher { } // namespace infini_train #define REGISTER_KERNEL(device, kernel_name, kernel_func) \ - // =================================== 作业 =================================== - // TODO:实现自动注册宏 - // 功能描述:在全局静态区注册kernel,避免显式初始化代码 - // =================================== 作业 =================================== + static const bool _##kernel_name##_registered##__COUNTER__ = []() { \ + infini_train::Dispatcher::Instance().Register({device, #kernel_name}, kernel_func); \ + return true; \ + }(); diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index a6479de..cebe411 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -10,6 +10,7 @@ #include "Eigen/Dense" #include "glog/logging.h" +#include "infini_train/include/datatype.h" #include "infini_train/include/device.h" namespace infini_train { @@ -33,21 +34,6 @@ struct PrintOptions { }; } // namespace -enum class DataType : int8_t { - kUINT8, - kINT8, - kUINT16, - kINT16, - kUINT32, - kINT32, - kUINT64, - kINT64, - kBFLOAT16, - kFLOAT16, - kFLOAT32, - kFLOAT64, -}; - class TensorBuffer { public: TensorBuffer(Device device, size_t size); diff --git a/infini_train/src/autograd/elementwise.cc b/infini_train/src/autograd/elementwise.cc index 5a790a5..9c78fbe 100644 --- a/infini_train/src/autograd/elementwise.cc +++ b/infini_train/src/autograd/elementwise.cc @@ -7,20 +7,22 @@ namespace infini_train::autograd { std::vector> Neg::Forward(const std::vector> &input_tensors) { // =================================== 作业 =================================== - // TODO:通过Dispatcher获取设备专属kernel,对输入张量进行取反操作 - // NOTES: 依赖test_dispatcher,Neg kernel实现已给出 - // =================================== 作业 =================================== + CHECK_EQ(input_tensors.size(), 1); + const auto &input = input_tensors[0]; - return std::vector>(); + auto device = input->GetDevice().Type(); + auto kernel = Dispatcher::Instance().GetKernel({device, "NegForward"}); + return {kernel.Call>(input)}; } std::vector> Neg::Backward(const std::vector> &grad_outputs) { // =================================== 作业 =================================== - // TODO:通过Dispatcher获取设备专属的反向传播kernel,计算梯度 - // NOTES: 依赖test_dispatcher,Neg的kernel实现已给出 - // =================================== 作业 =================================== + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; - return 
std::vector>(); + auto device = grad_output->GetDevice().Type(); + auto kernel = Dispatcher::Instance().GetKernel({device, "NegBackward"}); + return {kernel.Call>(grad_output)}; } std::vector> Reciprocal::Forward(const std::vector> &input_tensors) { diff --git a/infini_train/src/device.cc b/infini_train/src/device.cc index cc6842f..a1f1018 100644 --- a/infini_train/src/device.cc +++ b/infini_train/src/device.cc @@ -2,6 +2,13 @@ #include +#ifdef USE_CUDA +#include "infini_train/include/common/cuda/common_cuda.cuh" +#endif +#ifdef USE_NCCL +#include "nccl.h" +#include +#endif #include "glog/logging.h" namespace infini_train { @@ -38,4 +45,95 @@ std::ostream &operator<<(std::ostream &os, const Device &device) { return os; } -} // namespace infini_train +CpuDevice::CpuDevice() : Device(DeviceType::kCPU, 0) {} + +#ifdef USE_CUDA +CudaDevice::~CudaDevice() { + if (stream_ != nullptr) { + cudaStreamDestroy(stream_); + } + + if (cublas_handle_ != nullptr) { + cublasDestroy(cublas_handle_); + } +} + +void CudaDevice::SetDevice() const { cudaSetDevice(index_); } +void CudaDevice::Synchronize() const { cudaDeviceSynchronize(); } + +cudaStream_t CudaDevice::Stream() const { return stream_; } + +cublasHandle_t CudaDevice::CublasHandle() const { return cublas_handle_; } + +#ifdef USE_NCCL +ncclComm_t CudaDevice::NcclComm() const { return nccl_comm_; } +#endif + +CudaDevice::CudaDevice(int8_t index) : Device(DeviceType::kCUDA, index) { + // TODO(dcj): make CudaDevice initialization lazy to avoid allocating memory on all GPUs in single-GPU mode + SetDevice(); + cudaStreamCreate(&stream_); + + cublasCreate(&cublas_handle_); + cublasSetStream(cublas_handle_, stream_); +} +#endif // USE_CUDA + +const DeviceManager *DeviceManager::Instance() { + static auto instance = std::unique_ptr(new DeviceManager()); +#ifdef USE_NCCL + static std::once_flag flag; + std::call_once(flag, [&]() { instance->InitNcclCommunicators(); }); +#endif + return instance.get(); +} + +const Device *DeviceManager::GetDevice(DeviceType type, int8_t index) const { + return devices_map_.at(type).at(index).get(); +} + +const Device *DeviceManager::GetDefaultDevice() const { return devices_map_.at(DeviceType::kCPU).at(0).get(); } + +std::vector DeviceManager::GetAllAvailableDevices(DeviceType device_type) const { + std::vector devices; + for (const auto &device : devices_map_.at(device_type)) { devices.push_back(device.get()); } + return devices; +} + +DeviceManager::DeviceManager() { + devices_map_[DeviceType::kCPU].push_back(std::unique_ptr(new CpuDevice())); +#ifdef USE_CUDA + CUDA_DRIVER_CHECK(cuInit(0)); + int device_count = 0; + CUDA_DRIVER_CHECK(cuDeviceGetCount(&device_count)); + int current_device = -1; + CUDA_CHECK(cudaGetDevice(¤t_device)); + for (int idx = 0; idx < device_count; ++idx) { + devices_map_[DeviceType::kCUDA].push_back(std::unique_ptr(new CudaDevice(idx))); + } + CUDA_CHECK(cudaSetDevice(current_device)); +#endif +} + +#ifdef USE_NCCL +void DeviceManager::InitNcclCommunicators() { + const auto &cuda_devices = devices_map_.at(DeviceType::kCUDA); + int num_devices = cuda_devices.size(); + + std::vector device_indices; + std::vector streams; + for (const auto &device : cuda_devices) { + const auto *cuda_device = dynamic_cast(device.get()); + device_indices.push_back(cuda_device->Index()); + } + + std::vector nccl_comms(num_devices, nullptr); + NCCL_CHECK(ncclCommInitAll(nccl_comms.data(), num_devices, device_indices.data())); + + for (int i = 0; i < num_devices; ++i) { + auto *device = 
dynamic_cast(cuda_devices[i].get()); + device->nccl_comm_ = nccl_comms[i]; + } +} +#endif +}// namespace infini_train diff --git a/infini_train/src/kernels/cpu/accumulate_grad.cc b/infini_train/src/kernels/cpu/accumulate_grad.cc index 55637cd..62e6f0f 100644 --- a/infini_train/src/kernels/cpu/accumulate_grad.cc +++ b/infini_train/src/kernels/cpu/accumulate_grad.cc @@ -15,9 +15,24 @@ void AdamAccumulateGrad(const std::shared_ptr &grad, const std::shared_p const std::shared_ptr &m, const std::shared_ptr &v, float learning_rate, float beta1, float beta2, float eps, int64_t t) { // =================================== 作业 =================================== - // TODO:实现Adam优化器的梯度累积和参数更新 - // REF: - // =================================== 作业 =================================== + const float *grad_data = static_cast(grad->DataPtr()); + float *m_data = static_cast(m->DataPtr()); + float *v_data = static_cast(v->DataPtr()); + float *param_data = static_cast(param->DataPtr()); + + const float bias_correction_m = 1.0f - std::pow(beta1, t); + const float bias_correction_v = 1.0f - std::pow(beta2, t); + +#pragma omp parallel for + for (size_t idx = 0; idx < grad->NumElements(); ++idx) { + m_data[idx] = beta1 * m_data[idx] + (1 - beta1) * grad_data[idx]; + v_data[idx] = beta2 * v_data[idx] + (1 - beta2) * grad_data[idx] * grad_data[idx]; + + const float m_hat = m_data[idx] / bias_correction_m; + const float v_hat = v_data[idx] / bias_correction_v; + + param_data[idx] -= learning_rate * m_hat / (std::sqrt(v_hat) + eps); + } } } // namespace infini_train::kernels::cpu diff --git a/infini_train/src/kernels/cpu/linear.cc b/infini_train/src/kernels/cpu/linear.cc index 140e756..a0554dc 100644 --- a/infini_train/src/kernels/cpu/linear.cc +++ b/infini_train/src/kernels/cpu/linear.cc @@ -12,11 +12,43 @@ namespace infini_train::kernels::cpu { std::shared_ptr MatmulForward(const std::shared_ptr &input, const std::shared_ptr &other) { // =================================== 作业 =================================== - // TODO:实现CPU上的矩阵乘法前向计算 - // REF: - // =================================== 作业 =================================== + /* + output[*, m, n] = input[*, m, k] * other[*, k, n] + */ + // TODO(dcj): support broadcast later + const auto &input_dims = input->Dims(); + const auto &other_dims = other->Dims(); + + CHECK_GE(input_dims.size(), 2); + CHECK_GE(other_dims.size(), 2); + CHECK_EQ(input_dims.size(), other_dims.size()); + + const int64_t m = input_dims[input_dims.size() - 2]; + const int64_t k = input_dims[input_dims.size() - 1]; + CHECK_EQ(k, other_dims[other_dims.size() - 2]); + const int64_t n = other_dims[other_dims.size() - 1]; + + const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < input_dims.size() - 2; ++i) { + CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; + } + + std::vector output_dims = input_dims; + output_dims[output_dims.size() - 1] = n; + auto output = std::make_shared(output_dims, DataType::kFLOAT32); - auto output = std::make_shared(); + for (int64_t b = 0; b < bs; ++b) { + for (int64_t i = 0; i < m; ++i) { + for (int64_t j = 0; j < n; ++j) { + float acc = 0.0f; + for (int64_t p = 0; p < k; ++p) { + acc += static_cast(input->DataPtr())[b * m * k + i * k + p] + * static_cast(other->DataPtr())[b * k * n + p * n + j]; + } + static_cast(output->DataPtr())[b * m * n + i * n + j] = acc; + } + } + } return {output}; } @@ -24,12 +56,51 @@ std::tuple, std::shared_ptr> MatmulBackward(const std::shared_ptr 
&input, const std::shared_ptr &other, const std::shared_ptr &grad_output) { // =================================== 作业 =================================== - // TODO:实现CPU上的矩阵乘法反向传播 - // REF: - // =================================== 作业 =================================== + /* + grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T + grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n] + */ + const auto &input_dims = input->Dims(); + const auto &other_dims = other->Dims(); + const auto &grad_output_dims = grad_output->Dims(); - auto grad_input = std::make_shared(); - auto grad_other = std::make_shared(); + CHECK_GE(input_dims.size(), 2); + CHECK_EQ(input_dims.size(), other_dims.size()); + CHECK_EQ(input_dims.size(), grad_output_dims.size()); + + const int64_t m = input_dims[input_dims.size() - 2]; + const int64_t k = input_dims[input_dims.size() - 1]; + CHECK_EQ(k, other_dims[other_dims.size() - 2]); + const int64_t n = other_dims[other_dims.size() - 1]; + CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]); + CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]); + + const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < input_dims.size() - 2; ++i) { + CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; + CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match"; + } + + auto grad_input = std::make_shared(input_dims, DataType::kFLOAT32); + auto grad_other = std::make_shared(other_dims, DataType::kFLOAT32); + grad_input->Fill(0.0f); + grad_other->Fill(0.0f); + + for (int64_t b = 0; b < bs; ++b) { + for (int64_t i = 0; i < m; ++i) { + for (int64_t j = 0; j < n; ++j) { + const float grad = static_cast(grad_output->DataPtr())[b * m * n + i * n + j]; + for (int64_t p = 0; p < k; ++p) { + const auto input_idx = b * m * k + i * k + p; + const auto other_idx = b * k * n + p * n + j; + static_cast(grad_input->DataPtr())[input_idx] + += grad * static_cast(other->DataPtr())[other_idx]; + static_cast(grad_other->DataPtr())[other_idx] + += grad * static_cast(input->DataPtr())[input_idx]; + } + } + } + } return {grad_input, grad_other}; } diff --git a/infini_train/src/kernels/cuda/accumulate_grad.cu b/infini_train/src/kernels/cuda/accumulate_grad.cu index 5f977c3..2680d54 100644 --- a/infini_train/src/kernels/cuda/accumulate_grad.cu +++ b/infini_train/src/kernels/cuda/accumulate_grad.cu @@ -1,5 +1,4 @@ -#include "infini_train/include/dispatcher.h" -#include "infini_train/include/tensor.h" +#include "infini_train/include/common/cuda/common_cuda.cuh" namespace infini_train::kernels::cuda { @@ -22,6 +21,25 @@ void AccumulateGrad(const std::shared_ptr &gradient, float rate, const s AccumulateGradKernel<<>>(grad_ptr, rate, tensor_ptr, num_elements); } +template +__global__ void AdamAccumulateGradKernel(const T *grad_data, T *param_data, size_t num_elements, T *m_data, T *v_data, + float learning_rate, float beta1, float beta2, float eps, + const float bias_correction_m, const float bias_correction_v) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < num_elements) { + m_data[idx] = common::cuda::Fma(common::cuda::Cast(beta1), m_data[idx], + common::cuda::Cast(1 - beta1) * grad_data[idx]); + v_data[idx] = common::cuda::Fma(common::cuda::Cast(beta2), v_data[idx], + common::cuda::Cast(1 - beta2) * grad_data[idx] * grad_data[idx]); + + const float m_hat = common::cuda::Cast(m_data[idx]) / bias_correction_m; + const float v_hat = common::cuda::Cast(v_data[idx]) / 
bias_correction_v; + + param_data[idx] = common::cuda::Sub( + param_data[idx], common::cuda::Cast(learning_rate * m_hat * __frcp_rn(__fsqrt_rn(v_hat) + eps))); + } +} + void AdamAccumulateGrad(const std::shared_ptr &grad, const std::shared_ptr ¶m, const std::shared_ptr &m, const std::shared_ptr &v, float learning_rate, float beta1, float beta2, float eps, int64_t t) { @@ -29,6 +47,25 @@ void AdamAccumulateGrad(const std::shared_ptr &grad, const std::shared_p // TODO:实现Adam优化器的梯度累积和参数更新 // REF: // =================================== 作业 =================================== + size_t num_elements = grad->NumElements(); + + const float bias_correction_m = 1.0f - std::pow(beta1, t); + const float bias_correction_v = 1.0f - std::pow(beta2, t); + + int threads_per_block = 256; + int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block; + const auto *cuda_device = dynamic_cast( + DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, grad->GetDevice().Index())); + + DispatchFunc( + grad->Dtype(), + [=]() { + AdamAccumulateGradKernel<<Stream()>>>( + static_cast(grad->DataPtr()), static_cast(param->DataPtr()), num_elements, + static_cast(m->DataPtr()), static_cast(v->DataPtr()), learning_rate, beta1, beta2, eps, + bias_correction_m, bias_correction_v); + }, + "CUDA AdamAccumulateGrad"); } } // namespace infini_train::kernels::cuda diff --git a/infini_train/src/kernels/cuda/linear.cu b/infini_train/src/kernels/cuda/linear.cu index efaaaa6..4ea54ac 100644 --- a/infini_train/src/kernels/cuda/linear.cu +++ b/infini_train/src/kernels/cuda/linear.cu @@ -2,6 +2,7 @@ #include "glog/logging.h" #include +#include "infini_train/include/common/cuda/common_cuda.cuh" #include "infini_train/include/dispatcher.h" #include "infini_train/include/tensor.h" @@ -25,11 +26,65 @@ namespace infini_train::kernels::cuda { std::shared_ptr MatmulForward(const std::shared_ptr &input, const std::shared_ptr &other) { // =================================== 作业 =================================== - // TODO:实现CUDA上的矩阵乘法前向计算 - // REF: - // =================================== 作业 =================================== + /* + output[*, m, n] = input[*, m, k] * other[*, k, n] + */ + const auto &input_dims = input->Dims(); + const auto &other_dims = other->Dims(); + + CHECK_GE(input_dims.size(), 2); + CHECK_GE(other_dims.size(), 2); + CHECK_EQ(input_dims.size(), other_dims.size()); + + const int64_t m = input_dims[input_dims.size() - 2]; + const int64_t k = input_dims[input_dims.size() - 1]; + CHECK_EQ(k, other_dims[other_dims.size() - 2]); + const int64_t n = other_dims[other_dims.size() - 1]; + + const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < input_dims.size() - 2; ++i) { + CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; + } + + auto dtype = input->Dtype(); + std::vector output_dims = input_dims; + output_dims[output_dims.size() - 1] = n; + auto output = std::make_shared(output_dims, dtype, input->GetDevice()); + + const auto *cuda_device = dynamic_cast( + DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, input->GetDevice().Index())); + const float alpha = 1.0f, beta = 0.0f; + cublasHandle_t handle = cuda_device->CublasHandle(); + + // cuBLAS is colmun-major + // output = input * other --> output.T = other.T * input.T + // C = A * B ==> output.T[*, n, m] = other.T[*, n, k] * input.T[*, k, m] + // C = output.T[*, n, m] + // A = other.T[*, n, k] + // B = input.T[*, k, m] + int lda = n; + int ldb = k; + int ldc = n; + 
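+ // Worked example of the leading dimensions above (illustrative numbers): for
+ // m=2, k=3, n=4, the row-major buffer of other (3x4) read as column-major with
+ // ld=n=4 is other^T (4x3), and input (2x3) read with ld=k=3 is input^T (3x2);
+ // the n x m product written with ld=n=4 is output^T in column-major order,
+ // which is exactly output (2x4) in row-major memory.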
int64_t stride_a = n * k; + int64_t stride_b = k * m; + int64_t stride_c = m * n; + // NOTE(zbl): the last cublasGemmAlgo_t param has no effect on GPU arch >= sm_80(Ampere) + + switch (dtype) { + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, other->DataPtr(), CUDA_R_32F, lda, + stride_a, input->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, output->DataPtr(), CUDA_R_32F, + ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kFLOAT32) + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, other->DataPtr(), CUDA_R_16BF, lda, + stride_a, input->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, output->DataPtr(), CUDA_R_16BF, + ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kBFLOAT16) + default: + LOG_UNSUPPORTED_DTYPE(dtype, "CUDA MatmulForward"); + } - auto output = std::make_shared(); return output; } @@ -37,12 +92,99 @@ std::tuple, std::shared_ptr> MatmulBackward(const std::shared_ptr &input, const std::shared_ptr &other, const std::shared_ptr &grad_output) { // =================================== 作业 =================================== - // TODO:实现CUDA上的矩阵乘法反向传播 - // REF: - // =================================== 作业 =================================== + /* + grad_input[*, m, k] = grad_output[*, m, n] * other[*, k, n]^T + grad_other[*, k, n] = input[*, m, k]^T * grad_output[*, m, n] + */ + const auto &input_dims = input->Dims(); + const auto &other_dims = other->Dims(); + const auto &grad_output_dims = grad_output->Dims(); + + CHECK_GE(input_dims.size(), 2); + CHECK_EQ(input_dims.size(), other_dims.size()); + CHECK_EQ(input_dims.size(), grad_output_dims.size()); + + const int64_t m = input_dims[input_dims.size() - 2]; + const int64_t k = input_dims[input_dims.size() - 1]; + const int64_t n = other_dims[other_dims.size() - 1]; + CHECK_EQ(k, other_dims[other_dims.size() - 2]); + CHECK_EQ(m, grad_output_dims[grad_output_dims.size() - 2]); + CHECK_EQ(n, grad_output_dims[grad_output_dims.size() - 1]); + + const int64_t bs = std::accumulate(input_dims.rbegin() + 2, input_dims.rend(), 1, std::multiplies{}); + for (int64_t i = 0; i < input_dims.size() - 2; ++i) { + CHECK_EQ(input_dims[i], other_dims[i]) << "Batch dims must match"; + CHECK_EQ(input_dims[i], grad_output_dims[i]) << "Batch dims must match"; + } + + auto dtype = input->Dtype(); + auto grad_input = std::make_shared(input_dims, dtype, grad_output->GetDevice()); + auto grad_other = std::make_shared(other_dims, dtype, grad_output->GetDevice()); + + DispatchFunc( + dtype, + [=]() { + grad_input->Fill(0); + grad_other->Fill(0); + }, + "CUDA MatmulBackward"); + + const auto *cuda_device = dynamic_cast( + DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, input->GetDevice().Index())); + const float alpha = 1.0f, beta = 0.0f; + cublasHandle_t handle = cuda_device->CublasHandle(); + + { + // cuBLAS is colmun-major + // grad_input = grad_output * other.T --> grad_input.T = other * grad_output.T + // C = A.T * B ==> grad_input.T[*, k, m] = other[*, k, n] * grad_output.T[*, n, m] + // C = grad_input.T[*, k, m] + // A = other.T[*, n, k] + // B = grad_output.T[*, n, m] + const int lda = n, ldb = n, ldc = k; + const int64_t stride_a = k * n; + const int64_t stride_b = n * m; + const int64_t stride_c = m * k; + switch (dtype) { + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other->DataPtr(), CUDA_R_32F, lda, + stride_a, 
grad_output->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, grad_input->DataPtr(), + CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kFLOAT32) + DISPATCH_CASE( + WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other->DataPtr(), CUDA_R_16BF, lda, stride_a, + grad_output->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, grad_input->DataPtr(), CUDA_R_16BF, ldc, + stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kBFLOAT16) + } + } + + { + // cuBLAS is colmun-major + // grad_other = input.T * grad_output --> grad_other.T = grad_output.T * input + // C = A * B.T ==> grad_other.T[*, n, k] = grad_output.T[*, n, m] * input[*, m, k] + // C = grad_other.T[*, n, k] + // A = grad_output.T[*, n, m] + // B = input.T[*, k, m] + const int lda = n, ldb = k, ldc = n; + const int64_t stride_a = n * m; + const int64_t stride_b = k * m; + const int64_t stride_c = n * k; + switch (dtype) { + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output->DataPtr(), CUDA_R_32F, + lda, stride_a, input->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, grad_other->DataPtr(), + CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kFLOAT32) + DISPATCH_CASE(WRAP(CUBLAS_CHECK(cublasGemmStridedBatchedEx( + handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output->DataPtr(), CUDA_R_16BF, + lda, stride_a, input->DataPtr(), CUDA_R_16BF, ldb, stride_b, &beta, grad_other->DataPtr(), + CUDA_R_16BF, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));), + DataType::kBFLOAT16) + } + } - auto grad_input = std::make_shared(); - auto grad_other = std::make_shared(); return {grad_input, grad_other}; } diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index 8f8c744..ca36d08 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -25,36 +25,6 @@ #include "infini_train/include/nn/init.h" namespace infini_train { -namespace { -const std::unordered_map kDataTypeToSize = { - {DataType::kUINT8, 1}, {DataType::kINT8, 1}, {DataType::kUINT16, 2}, {DataType::kINT16, 2}, - {DataType::kUINT32, 4}, {DataType::kINT32, 4}, {DataType::kUINT64, 8}, {DataType::kINT64, 8}, - {DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4}, {DataType::kFLOAT64, 8}, -}; - -const std::unordered_map kDataTypeToDesc = { - {DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"}, {DataType::kUINT16, "uint16"}, - {DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"}, {DataType::kINT32, "int32"}, - {DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"}, {DataType::kBFLOAT16, "bf16"}, - {DataType::kFLOAT16, "fp16"}, {DataType::kFLOAT32, "fp32"}, {DataType::kFLOAT64, "fp64"}, -}; - -template struct TypeMap; - -template <> struct TypeMap { - using type = float; -}; -template <> struct TypeMap { - using type = double; -}; -template <> struct TypeMap { - using type = int32_t; -}; -template <> struct TypeMap { - using type = int64_t; -}; -} // namespace - TensorBuffer::TensorBuffer(Device device, size_t size) : device_(device), size_(size) { switch (device_.Type()) { case DeviceType::kCPU: @@ -122,44 +92,36 @@ size_t Tensor::NumElements() const { return num_elements_; } DataType Tensor::Dtype() const { return dtype_; } template void Tensor::Fill(T value) { + auto device = GetDevice(); + device.SetDevice(); + DataType dtype = Dtype(); uint64_t storage = 0; - switch (dtype) { - case DataType::kFLOAT32: { - using TargetT = typename 
TypeMap::type; - TargetT casted_value = static_cast(value); - std::memcpy(&storage, &casted_value, sizeof(TargetT)); - break; - } - case DataType::kFLOAT64: { - using TargetT = typename TypeMap::type; - TargetT casted_value = static_cast(value); - std::memcpy(&storage, &casted_value, sizeof(TargetT)); - break; - } - case DataType::kINT32: { - using TargetT = typename TypeMap::type; - TargetT casted_value = static_cast(value); - std::memcpy(&storage, &casted_value, sizeof(TargetT)); - break; - } - case DataType::kINT64: { - using TargetT = typename TypeMap::type; + DispatchFunc(Dtype(), [&storage, value]() { TargetT casted_value = static_cast(value); - std::memcpy(&storage, &casted_value, sizeof(TargetT)); - break; - } - default: - throw std::runtime_error("Unsupported data type in Tensor::Fill()"); - } + std::memcpy((void *)(&storage), &casted_value, sizeof(TargetT)); + }); - auto kernel = Dispatcher::Instance().GetKernel({GetDevice().Type(), "Fill"}); + auto kernel = Dispatcher::Instance().GetKernel({device.Type(), "Fill"}); kernel.Call(shared_from_this(), static_cast(&storage)); } +template void Tensor::Fill(uint8_t); +template void Tensor::Fill(int8_t); +template void Tensor::Fill(uint16_t); +template void Tensor::Fill(int16_t); +template void Tensor::Fill(uint32_t); +template void Tensor::Fill(int32_t); +template void Tensor::Fill(uint64_t); +template void Tensor::Fill(int64_t); template void Tensor::Fill(float); +template void Tensor::Fill(double); +#ifdef USE_CUDA +template void Tensor::Fill(nv_bfloat16); +template void Tensor::Fill(half); +#endif Eigen::Map> Tensor::EigenMatrix() { const int64_t bs = std::accumulate(dims_.rbegin() + 1, dims_.rend(), 1, std::multiplies()); @@ -279,11 +241,25 @@ std::shared_ptr Tensor::Contiguous() { std::shared_ptr Tensor::Flatten(int64_t start, int64_t end) { // return Contiguous()->View(new_shape); // =================================== 作业 =================================== - // TODO:实现张量扁平化操作,将指定维度范围[start, end]内的所有维度合并为一个维度 - // HINT: - // =================================== 作业 =================================== + auto ndim = dims_.size(); + auto start_dim = start >= 0 ? start : start + ndim; + auto end_dim = end >= 0 ? end : end + ndim; + CHECK(start_dim >= 0 && end_dim >= start_dim && end_dim <= ndim); + + std::vector new_shape; + int64_t flatten_size = 1; + for (int64_t i = 0; i < ndim; ++i) { + if (i < start_dim || i > end_dim) { + new_shape.push_back(dims_[i]); + } else { + flatten_size *= dims_[i]; + if (i == end_dim) { + new_shape.push_back(flatten_size); + } + } + } - return std::make_shared(); + return Contiguous()->View(new_shape); } std::shared_ptr Tensor::Squeeze(int64_t dim) { @@ -355,9 +331,20 @@ std::shared_ptr Tensor::RequiresGrad() { void Tensor::Backward(std::shared_ptr gradient, bool retain_graph, bool create_graph) const { // =================================== 作业 =================================== - // TODO:实现自动微分反向传播 - // 功能描述:1. 计算当前张量对叶子节点的梯度 2. 
支持多输出场景的梯度累加 - // =================================== 作业 =================================== + CHECK(!retain_graph && !create_graph) << "Not implemented yet!"; + if (grad_fn_) { + if (!gradient) { + CHECK_EQ(dims_.size(), 0); + gradient = std::make_shared(std::vector{}, dtype_, GetDevice()); + gradient->Fill(1.0f); + } else { + CHECK_EQ(static_cast(GetDevice().Type()), static_cast(gradient->GetDevice().Type())); + CHECK_EQ(static_cast(dtype_), static_cast(gradient->Dtype())); + CHECK_EQ(dims_.size(), gradient->Dims().size()); + for (int idx = 0; idx < dims_.size(); ++idx) { CHECK_EQ(dims_[idx], gradient->Dims()[idx]); } + } + grad_fn_->BackwardPartial(gradient, output_idx_); + } } void Tensor::ZeroGrad() { diff --git a/test/optimizer/test_adam.cc b/test/optimizer/test_adam.cc index 7eb83bb..286ea62 100644 --- a/test/optimizer/test_adam.cc +++ b/test/optimizer/test_adam.cc @@ -67,16 +67,15 @@ TEST(AdamOptimizerTest, BasicParameterUpdateCuda) { param->Fill(1.0f); // 初始参数值 [1.0, 1.0, 1.0] param->RequiresGrad(); - auto grad = std::make_shared(param->Dims(), param->Dtype()); - grad->Fill(1.0f); - float* grad_data = static_cast(param->grad()->DataPtr()); - std::memcpy(grad_data, grad->DataPtr(), grad->SizeInBytes()); + // 直接设置梯度,不需要额外的grad tensor + param->grad()->Fill(1.0f); optimizers::Adam optimizer({param}, 0.001f, 0.9f, 0.999f, 1e-8); optimizer.Step(); - float* param_data = static_cast(param->DataPtr()); + auto param_cpu = param->To(Device(DeviceType::kCPU, 0)); + float* param_data = static_cast(param_cpu.DataPtr()); for (int i = 0; i < 3; ++i) { EXPECT_LT(param_data[i], 1.0f); // 参数值应该减小 } @@ -96,7 +95,8 @@ TEST(AdamOptimizerTest, MomentumAccumulationCuda) { std::vector param_history; for (int i = 0; i < 3; ++i) { optimizer.Step(); - param_history.push_back(static_cast(param->DataPtr())[0]); + auto param_cpu = param->To(Device(DeviceType::kCPU, 0)); + param_history.push_back(static_cast(param_cpu.DataPtr())[0]); } EXPECT_LT(param_history[1], param_history[0]); diff --git a/test/optimizer/test_adam_cuda.cc b/test/optimizer/test_adam_cuda.cc index 0a6ccda..0130c40 100644 --- a/test/optimizer/test_adam_cuda.cc +++ b/test/optimizer/test_adam_cuda.cc @@ -15,16 +15,15 @@ TEST(AdamOptimizerTest, BasicParameterUpdateCuda) { param->Fill(1.0f); // 初始参数值 [1.0, 1.0, 1.0] param->RequiresGrad(); - auto grad = std::make_shared(param->Dims(), param->Dtype()); - grad->Fill(1.0f); - float* grad_data = static_cast(param->grad()->DataPtr()); - std::memcpy(grad_data, grad->DataPtr(), grad->SizeInBytes()); + // 直接设置梯度,不需要额外的grad tensor + param->grad()->Fill(1.0f); optimizers::Adam optimizer({param}, 0.001f, 0.9f, 0.999f, 1e-8); optimizer.Step(); - float* param_data = static_cast(param->DataPtr()); + auto param_cpu = param->To(Device(DeviceType::kCPU, 0)); + float* param_data = static_cast(param_cpu.DataPtr()); for (int i = 0; i < 3; ++i) { EXPECT_LT(param_data[i], 1.0f); // 参数值应该减小 } @@ -44,7 +43,8 @@ TEST(AdamOptimizerTest, MomentumAccumulationCuda) { std::vector param_history; for (int i = 0; i < 3; ++i) { optimizer.Step(); - param_history.push_back(static_cast(param->DataPtr())[0]); + auto param_cpu = param->To(Device(DeviceType::kCPU, 0)); + param_history.push_back(static_cast(param_cpu.DataPtr())[0]); } EXPECT_LT(param_history[1], param_history[0]); diff --git a/third_party/eigen b/third_party/eigen index 2cf66d4..1e65707 160000 --- a/third_party/eigen +++ b/third_party/eigen @@ -1 +1 @@ -Subproject commit 2cf66d4b0d0ba52cbf2507e15998c4735aa14406 +Subproject commit 1e65707aa20603fc2ee9c2ac21c466ef57d23e10 
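For reference, the AdamAccumulateGrad kernels in this patch (CPU and CUDA) and the optimizer tests above implement the standard bias-corrected Adam update; in the notation of the kernel code, with learning rate $\eta$, gradient $g_t$ and parameter $\theta$:

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2, \\
\hat m_t &= \frac{m_t}{1-\beta_1^{\,t}}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^{\,t}}, \\
\theta_t &= \theta_{t-1} - \eta\,\frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon}.
\end{aligned}
$$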
diff --git a/third_party/googletest b/third_party/googletest
index 309dab8..32f9f4c 160000
--- a/third_party/googletest
+++ b/third_party/googletest
@@ -1 +1 @@
-Subproject commit 309dab8d4bbfcef0ef428762c6fec7172749de0f
+Subproject commit 32f9f4c82afa4249af66b55278df15c16b3031ea
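A minimal self-contained sketch of the type-erased call pattern that KernelFunction::Call fills in above (illustrative only; the patch stores the kernel as an erased pointer, and the names ErasedFn, Add and CallErased here are hypothetical): the stored pointer is cast back to RetT (*)(ArgsT...), so the RetT and ArgsT supplied at the call site must match the registered kernel's true signature exactly, otherwise the call is undefined behavior.

```c++
#include <iostream>

// Erased function-pointer type, analogous to the pointer stored in KernelFunction.
using ErasedFn = void (*)();

int Add(int a, int b) { return a + b; }

// Cast the erased pointer back to its real signature and call it,
// mirroring the logic of KernelFunction::Call (void-return handling omitted).
template <typename RetT, typename... ArgsT>
RetT CallErased(ErasedFn fp, ArgsT... args) {
    auto fn = reinterpret_cast<RetT (*)(ArgsT...)>(fp);
    return fn(args...);
}

int main() {
    ErasedFn fp = reinterpret_cast<ErasedFn>(&Add);
    std::cout << CallErased<int, int, int>(fp, 2, 3) << '\n'; // prints 5
    return 0;
}
```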