Skip to content

Commit ff0593e

Browse files
committed
fix: rerank detection
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent a86c04a commit ff0593e

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

file_architecture.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ type (
117117
RoPEScalingOriginalContextLength uint64 `json:"ropeScalingOriginalContextLength,omitempty"`
118118
// RoPEScalingFinetuned is true if the RoPE scaling is fine-tuned.
119119
RoPEScalingFinetuned bool `json:"ropeScalingFinetuned,omitempty"`
120+
// PoolingType is the type of pooling used in the model.
121+
PoolingType uint32 `json:"poolingType,omitempty"`
120122
// SSMConvolutionKernel is the size of the convolution kernel used in the Selective State Space Model (SSM) and similar architectures.
121123
SSMConvolutionKernel uint32 `json:"ssmConvolutionKernel,omitempty"`
122124
// SSMInnerSize is the embedding size of the state in SSM and similar architectures.
@@ -857,6 +859,8 @@ func (gf *GGUFFile) transformerArchitecture(arch string) (ga GGUFArchitecture) {
857859
ropeScalingOriginalContextKey = arch + ".rope.scaling.original_context_length" // uint32 maybe
858860
ropeScalingFinetunedKey = arch + ".rope.scaling.finetuned"
859861

862+
poolingTypeKey = arch + ".pooling_type"
863+
860864
ssmConvolutionKernelKey = arch + ".ssm.conv_kernel"
861865
ssmInnerSizeKey = arch + ".ssm.inner_size"
862866
ssmStateSizeKey = arch + ".ssm.state_size"
@@ -910,6 +914,7 @@ func (gf *GGUFFile) transformerArchitecture(arch string) (ga GGUFArchitecture) {
910914
ropeScalingFactorKey,
911915
ropeScalingOriginalContextKey,
912916
ropeScalingFinetunedKey,
917+
poolingTypeKey,
913918
ssmConvolutionKernelKey,
914919
ssmInnerSizeKey,
915920
ssmStateSizeKey,
@@ -1098,6 +1103,13 @@ func (gf *GGUFFile) transformerArchitecture(arch string) (ga GGUFArchitecture) {
10981103
ga.RoPEScalingFinetuned = v.ValueBool()
10991104
}
11001105

1106+
if v, ok := m[poolingTypeKey]; ok {
1107+
ga.PoolingType = v.ValueUint32()
1108+
if ga.AttentionCausal && ga.PoolingType > 2 {
1109+
ga.AttentionCausal = false
1110+
}
1111+
}
1112+
11011113
if v, ok := m[ssmConvolutionKernelKey]; ok {
11021114
ga.SSMConvolutionKernel = ValueNumeric[uint32](v)
11031115
}

file_estimate__llamacpp.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,9 @@ func (gf *GGUFFile) estimateLLaMACppRunInModel(o *_GGUFRunEstimateOptions, a *GG
373373
if _, found := gf.TensorInfos.Index([]string{"cls.bias", "cls.weight"}); found > 0 {
374374
e.Reranking = true
375375
}
376+
if !e.Reranking && a.PoolingType == 4 { // 0: None, 1: Mean, 2: Cls, 3: Last, 4: Rank
377+
e.Reranking = true
378+
}
376379
}
377380

378381
// Distributable,

0 commit comments

Comments
 (0)