Skip to content

Commit 8ada1ea

Browse files
committed
fix: fix guided decoding state corruption in turbomind when tp>1
1 parent 8258be5 commit 8ada1ea

File tree

4 files changed

+20
-6
lines changed

4 files changed

+20
-6
lines changed

src/turbomind/layers/BaseDynamicDecodeLayer.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class BaseDynamicDecodeLayer {
3131
int vocab_size_padded;
3232
cudaStream_t stream;
3333
const cudaDeviceProp* device_prop;
34+
int tp_rank;
3435
};
3536

3637
virtual ~BaseDynamicDecodeLayer() = default;
@@ -42,6 +43,7 @@ class BaseDynamicDecodeLayer {
4243
vocab_size_padded_ = param.vocab_size_padded;
4344
stream_ = param.stream;
4445
device_prop_ = param.device_prop;
46+
tp_rank_ = param.tp_rank;
4547
};
4648

4749
virtual void Setup(const std::vector<const Request*>& rs, const TensorMap& args) = 0;
@@ -54,6 +56,7 @@ class BaseDynamicDecodeLayer {
5456
int vocab_size_padded_;
5557
cudaStream_t stream_;
5658
const cudaDeviceProp* device_prop_;
59+
int tp_rank_;
5760
};
5861

5962
} // namespace turbomind

src/turbomind/layers/DynamicDecodeLayer.cc

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,22 @@ DynamicDecodeLayer::DynamicDecodeLayer(DataType dtype,
3131
int vocab_size,
3232
int vocab_size_padded,
3333
cudaStream_t stream,
34-
const cudaDeviceProp* device_prop)
34+
const cudaDeviceProp* device_prop,
35+
int tp_rank):
36+
tp_rank_(tp_rank)
3537
{
3638
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
3739
TM_CHECK(dtype == kFloat32);
38-
BaseDynamicDecodeLayer::BaseParam param{max_batch_size, vocab_size, vocab_size_padded, stream, device_prop};
40+
BaseDynamicDecodeLayer::BaseParam param{
41+
max_batch_size, vocab_size, vocab_size_padded, stream, device_prop, tp_rank};
3942
layers_.emplace_back(new LogitsProcessorLayer<float>{param});
40-
layers_.emplace_back(new GuidedDecodeMaskLayer<float>{param});
43+
if (tp_rank == 0) {
44+
layers_.emplace_back(new GuidedDecodeMaskLayer<float>{param});
45+
}
4146
layers_.emplace_back(new SamplingLayer<float>{param});
42-
layers_.emplace_back(new GuidedDecodeUpdateLayer<float>{param});
47+
if (tp_rank == 0) {
48+
layers_.emplace_back(new GuidedDecodeUpdateLayer<float>{param});
49+
}
4350
layers_.emplace_back(new StopCriteriaLayer<float>{param});
4451
}
4552

@@ -48,6 +55,7 @@ DynamicDecodeLayer::~DynamicDecodeLayer() {}
4855
void DynamicDecodeLayer::Setup(const std::vector<const Request*>& rs, const TensorMap& args)
4956
{
5057
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
58+
TM_LOG_INFO("[Dynamic Decode] Setup layer for %d", tp_rank_);
5159
for (const auto& layer : layers_) {
5260
layer->Setup(rs, args);
5361
}
@@ -82,6 +90,7 @@ void DynamicDecodeLayer::Forward(TensorMap& args)
8290
* \param sampled_nums [batch_size, 1], optional
8391
*/
8492

93+
TM_LOG_INFO("[Dynamic Decode] Forward for %d", tp_rank_);
8594
for (const auto& layer : layers_) {
8695
layer->Forward(args);
8796
}

src/turbomind/layers/DynamicDecodeLayer.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ class DynamicDecodeLayer {
3333
int vocab_size,
3434
int vocab_size_padded,
3535
cudaStream_t stream,
36-
const cudaDeviceProp* device_prop);
36+
const cudaDeviceProp* device_prop,
37+
int tp_rank);
3738

3839
~DynamicDecodeLayer();
3940

@@ -42,6 +43,7 @@ class DynamicDecodeLayer {
4243
void Forward(TensorMap& args);
4344

4445
private:
46+
int tp_rank_;
4547
std::vector<std::unique_ptr<BaseDynamicDecodeLayer>> layers_;
4648
};
4749

src/turbomind/models/llama/LlamaV2.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ LlamaV2::LlamaV2(DataType dtype,
9090

9191
// using float to avoid data overflow
9292
dynamic_decode_ = std::make_unique<DynamicDecodeLayer>(
93-
kFloat32, max_batch_size, vocab_size_, vocab_size_padded_, stream_, &ctx.device_prop);
93+
kFloat32, max_batch_size, vocab_size_, vocab_size_padded_, stream_, &ctx.device_prop, tp_rank_);
9494
}
9595

9696
void LlamaV2::updateEmbedding(char* decoder_input,

0 commit comments

Comments (0)