forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathopUtils.cpp
More file actions
357 lines (323 loc) · 11.8 KB
/
opUtils.cpp
File metadata and controls
357 lines (323 loc) · 11.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/runtime/utils/mpiTags.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "cuda.h"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <functional>
#include <mutex>
#include <thread>
#if ENABLE_MULTI_DEVICE
std::unordered_map<nvinfer1::DataType, ncclDataType_t>* getDtypeMap()
{
static std::unordered_map<nvinfer1::DataType, ncclDataType_t> dtypeMap = {
{nvinfer1::DataType::kFLOAT, ncclFloat32},
{nvinfer1::DataType::kHALF, ncclFloat16},
{nvinfer1::DataType::kBF16, ncclBfloat16},
{nvinfer1::DataType::kFP8, ncclInt8},
{nvinfer1::DataType::kBOOL, ncclInt8},
{nvinfer1::DataType::kINT32, ncclInt32},
{nvinfer1::DataType::kINT64, ncclInt64},
{nvinfer1::DataType::kUINT8, ncclUint8},
{nvinfer1::DataType::kINT8, ncclInt8},
};
return &dtypeMap;
}
namespace
{

// Get NCCL unique ID for a group of ranks.
// The smallest rank in `group` (std::set is ordered, so group.begin()) generates the id with
// ncclGetUniqueId and sends it to every other member over COMM_SESSION; every other member
// receives it from that leader. All members of `group` must call this together.
//
// Throws (via TLLM_CHECK) if the calling rank is not a member of `group`; this also rejects an
// empty group. Without the check, an empty group would dereference group.begin() of an empty set,
// and a non-member would wait in recvValue for an id the leader never sends to it.
ncclUniqueId getUniqueId(std::set<int> const& group)
{
    auto const rank = COMM_SESSION.getRank();
    TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
    TLLM_CHECK(group.count(rank) != 0);
    int const leader = *group.begin();
    ncclUniqueId id;
    if (rank == leader)
    {
        NCCLCHECK_THROW(ncclGetUniqueId(&id));
        // Distribute the id to every non-leader member.
        for (auto it = std::next(group.begin()); it != group.end(); ++it)
        {
            COMM_SESSION.sendValue(id, *it, tensorrt_llm::mpi::MpiTag::kDefault);
        }
    }
    else
    {
        COMM_SESSION.recvValue(id, leader, tensorrt_llm::mpi::MpiTag::kDefault);
    }
    TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
    return id;
}

} // namespace
std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group)
{
auto const rank = COMM_SESSION.getRank();
TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
static std::map<std::set<int>, std::shared_ptr<ncclComm_t>> commMap;
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
std::ostringstream oss;
int index = 0;
for (auto const& rank : group)
{
if (index != 0)
{
oss << ",";
}
oss << rank;
index++;
}
auto groupStr = oss.str();
auto it = commMap.find(group);
if (it != commMap.end())
{
auto ncclComm = it->second;
TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank);
return ncclComm;
}
TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank);
ncclUniqueId id = getUniqueId(group);
int groupRank = 0;
for (auto const& currentRank : group)
{
if (rank == currentRank)
break;
++groupRank;
}
TLLM_CHECK(static_cast<size_t>(groupRank) < group.size());
std::shared_ptr<ncclComm_t> ncclComm(new ncclComm_t,
[](ncclComm_t* comm)
{
ncclCommDestroy(*comm);
delete comm;
});
#if defined(_WIN32)
// Need static connection initialization for accurate KV cache size estimation
if (getenv("NCCL_RUNTIME_CONNECT") == nullptr)
_putenv_s("NCCL_RUNTIME_CONNECT", "0");
// Disable graph register to avoid startup hangs
if (getenv("NCCL_GRAPH_REGISTER") == nullptr)
_putenv_s("NCCL_GRAPH_REGISTER", "0");
#else
setenv("NCCL_RUNTIME_CONNECT", "0", 0);
setenv("NCCL_GRAPH_REGISTER", "0", 0);
#endif // _WIN32
NCCLCHECK_THROW(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank));
commMap[group] = ncclComm;
TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
return ncclComm;
}
#endif // ENABLE_MULTI_DEVICE
// Return the address of COMM_SESSION (the communication session from mpiUtils.h) as an opaque
// pointer. In builds without ENABLE_MULTI_DEVICE there is no session and nullptr is returned.
void const* tensorrt_llm::common::op::getCommSessionHandle()
{
    void const* handle = nullptr;
#if ENABLE_MULTI_DEVICE
    handle = &COMM_SESSION;
#endif // ENABLE_MULTI_DEVICE
    return handle;
}
namespace
{

using tensorrt_llm::common::op::hash;

// Get current cuda context, a default context will be created if there is no context.
// Throws (via TLLM_CUDA_CHECK / TLLM_CHECK) if a context still cannot be obtained.
inline CUcontext getCurrentCudaCtx()
{
    CUcontext ctx{};
    CUresult err = cuCtxGetCurrent(&ctx);
    if (err == CUDA_ERROR_NOT_INITIALIZED || ctx == nullptr)
    {
        // A no-op runtime call makes the CUDA runtime initialise and bind a context to this thread;
        // the driver query is then retried.
        TLLM_CUDA_CHECK(cudaFree(nullptr));
        err = cuCtxGetCurrent(&ctx);
    }
    TLLM_CHECK(err == CUDA_SUCCESS);
    return ctx;
}

// Helper to create per-cuda-context and per-thread singleton managed by std::shared_ptr.
// Unlike conventional singletons, singleton created with this will be released
// when not needed, instead of on process exit.
// Objects of this class shall always be declared static / global, and shall never own CUDA
// resources.
template <typename T>
class PerCudaCtxPerThreadSingletonCreator
{
public:
    using CreatorFunc = std::function<std::unique_ptr<T>()>;
    using DeleterFunc = std::function<void(T*)>;

    // creator returning std::unique_ptr is by design.
    // It forces separation of memory for T and memory for control blocks.
    // So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released.
    // creator itself must not own CUDA resources. Only the object it creates can.
    PerCudaCtxPerThreadSingletonCreator(CreatorFunc creator, DeleterFunc deleter)
        : mCreator{std::move(creator)}
        , mDeleter{std::move(deleter)}
        , mObservers{new std::unordered_map<CacheKey, std::weak_ptr<T>, hash<CacheKey>>()}
    {
    }

    // The deleters of issued shared_ptrs capture `this`, so the creator must stay at a fixed address.
    PerCudaCtxPerThreadSingletonCreator(PerCudaCtxPerThreadSingletonCreator const&) = delete;
    PerCudaCtxPerThreadSingletonCreator& operator=(PerCudaCtxPerThreadSingletonCreator const&) = delete;

    ~PerCudaCtxPerThreadSingletonCreator()
    {
        std::lock_guard<std::mutex> lk{mMutex};
        delete mObservers;
        // Deleters that run after this point see nullptr and skip observer cleanup.
        mObservers = nullptr;
    }

    // Return the live instance for (current CUDA context, calling thread), creating one via the
    // creator function if none is currently alive for that key.
    std::shared_ptr<T> operator()()
    {
        std::lock_guard<std::mutex> lk{mMutex};
        CUcontext ctx{getCurrentCudaCtx()};
        std::thread::id thread = std::this_thread::get_id();
        auto const key = std::make_tuple(ctx, thread);
        std::shared_ptr<T> result = (*mObservers)[key].lock();
        if (result == nullptr)
        {
            // printf-style logging: the context is a pointer (%p); std::thread::id has no printf
            // conversion, so its std::hash value is logged (%zu) instead.
            TLLM_LOG_TRACE("creating singleton instance for CUDA context %p and thread %zu", static_cast<void*>(ctx),
                std::hash<std::thread::id>{}(thread));
            // Create the resource and register with an observer.
            result = std::shared_ptr<T>{mCreator().release(),
                [this, key](T* obj)
                {
                    if (obj == nullptr)
                    {
                        return;
                    }
                    mDeleter(obj);
                    if (mObservers == nullptr)
                    {
                        return;
                    }
                    // Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts
                    // frequently.
                    std::shared_ptr<T> observedObjHolder; // Delay destroy to avoid dead lock.
                    std::lock_guard<std::mutex> lk{mMutex};
                    // Must check observer again because another thread may created new instance for this ctx and this
                    // thread just before we lock mMutex. We can't infer that the observer is stale from the fact that
                    // obj is destroyed, because shared_ptr ref-count checking and observer removing are not in one
                    // atomic operation, and the observer may be changed to observe another instance.
                    auto it = mObservers->find(key);
                    if (it == mObservers->end())
                    {
                        return;
                    }
                    observedObjHolder = it->second.lock();
                    if (observedObjHolder == nullptr)
                    {
                        mObservers->erase(it);
                    }
                }};
            (*mObservers)[key] = result;
        }
        else
        {
            TLLM_LOG_TRACE("singleton instance for CUDA context %p and thread %zu is cached", static_cast<void*>(ctx),
                std::hash<std::thread::id>{}(thread));
        }
        return result;
    }

private:
    CreatorFunc mCreator;
    DeleterFunc mDeleter;
    mutable std::mutex mMutex;
    // CUDA resources are per-context and per-thread.
    using CacheKey = std::tuple<CUcontext, std::thread::id>;
    // Owned; deleted and set to nullptr by the destructor under mMutex.
    std::unordered_map<CacheKey, std::weak_ptr<T>, hash<CacheKey>>* mObservers;
};

// Structure to hold memory information
struct MemoryInfo
{
    size_t free_mb;
    size_t total_mb;
    float free_percent;
};

// Helper function to get current memory information (as reported by cudaMemGetInfo for the
// current device). Sizes are truncated to whole MiB.
MemoryInfo getMemoryInfo()
{
    size_t free_mem = 0, total_mem = 0;
    TLLM_CUDA_CHECK(cudaMemGetInfo(&free_mem, &total_mem));
    size_t const free_mb = free_mem / (1024 * 1024);
    size_t const total_mb = total_mem / (1024 * 1024);
    // Guard against division by zero if the driver ever reports a zero total.
    float const free_percent = (total_mem > 0) ? (static_cast<float>(free_mem) / total_mem * 100.0f) : 0.0f;
    return {free_mb, total_mb, free_percent};
}

// Helper function to log current memory usage at DEBUG level.
void logMemoryUsage(char const* operation, CUcontext ctx)
{
    auto const mem = getMemoryInfo();
    TLLM_LOG_DEBUG("%s: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB", operation, static_cast<void*>(ctx),
        mem.free_mb, mem.free_percent, mem.total_mb);
}

// Helper function to throw an error describing a failed cuBLAS/cuBLASLt handle creation,
// including the current free/total device memory.
void throwCublasErrorWithMemInfo(char const* operation, CUcontext ctx, cublasStatus_t status)
{
    auto const mem = getMemoryInfo();
    TLLM_THROW(
        "Failed to create %s. "
        "Status: %d, Context: %p, Free Memory: %zu MB (%.1f%%), Total: %zu MB. "
        "Consider reducing kv_cache_config.free_gpu_memory_fraction.",
        operation, static_cast<int>(status), static_cast<void*>(ctx), mem.free_mb, mem.free_percent, mem.total_mb);
}

} // namespace
std::shared_ptr<cublasHandle_t> getCublasHandle()
{
static PerCudaCtxPerThreadSingletonCreator<cublasHandle_t> creator(
[]() -> auto
{
CUcontext ctx = getCurrentCudaCtx();
logMemoryUsage("Creating cublas handle", ctx);
auto handle = std::make_unique<cublasHandle_t>();
auto status = cublasCreate(handle.get());
if (status != CUBLAS_STATUS_SUCCESS)
{
throwCublasErrorWithMemInfo("cublas handle", ctx, status);
}
return handle;
},
[](cublasHandle_t* handle)
{
auto status = cublasDestroy(*handle);
if (status != CUBLAS_STATUS_SUCCESS)
{
TLLM_LOG_WARNING("Failed to destroy cublas handle. Status: %d", status);
}
delete handle;
handle = nullptr;
});
return creator();
}
std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
{
static PerCudaCtxPerThreadSingletonCreator<cublasLtHandle_t> creator(
[]() -> auto
{
CUcontext ctx = getCurrentCudaCtx();
logMemoryUsage("Creating cublasLt handle", ctx);
auto handle = std::make_unique<cublasLtHandle_t>();
auto status = cublasLtCreate(handle.get());
if (status != CUBLAS_STATUS_SUCCESS)
{
throwCublasErrorWithMemInfo("cublasLt handle", ctx, status);
}
return handle;
},
[](cublasLtHandle_t* handle)
{
auto status = cublasLtDestroy(*handle);
if (status != CUBLAS_STATUS_SUCCESS)
{
TLLM_LOG_WARNING("Failed to destroy cublasLt handle. Status: %d", status);
}
delete handle;
handle = nullptr;
});
return creator();
}