diff --git a/README.md b/README.md
index 974607c..2211ce4 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@
+
{word}
-
{start} - {end} s
+
{start:.2f} - {end:.2f} s
+ {f'
{speaker_label}
' if speaker_label else ''}
"""
@@ -438,6 +449,7 @@ def run(audio_upload: Any, lang_disp: str, return_ts: bool):
text=getattr(t, "text", None),
start_time=getattr(t, "start_time", None),
end_time=getattr(t, "end_time", None),
+ speaker=getattr(t, "speaker", None),
)
for t in (getattr(r, "time_stamps", None) or [])
]
@@ -490,6 +502,10 @@ def main(argv=None) -> int:
user_backend_kwargs = _parse_json_dict(args.backend_kwargs, name="--backend-kwargs")
user_aligner_kwargs = _parse_json_dict(args.aligner_kwargs, name="--aligner-kwargs")
+ # Add campplus model path to aligner kwargs
+ if hasattr(args, "campplus_model") and args.campplus_model:
+ user_aligner_kwargs["campplus_model"] = args.campplus_model
+
backend_kwargs = _merge_dicts(_default_backend_kwargs(backend), user_backend_kwargs)
backend_kwargs = _coerce_special_types(backend_kwargs)
diff --git a/qwen_asr/inference/cluster_backend.py b/qwen_asr/inference/cluster_backend.py
new file mode 100644
index 0000000..80479ad
--- /dev/null
+++ b/qwen_asr/inference/cluster_backend.py
@@ -0,0 +1,192 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+# Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
+
+import scipy
+import torch
+import sklearn
+import numpy as np
+
+from sklearn.cluster._kmeans import k_means
+from sklearn.cluster import HDBSCAN
+
+
+class SpectralCluster:
+ r"""A spectral clustering method using the unnormalized Laplacian of the affinity matrix.
+ This implementation is adapted from https://github.com/speechbrain/speechbrain.
+ """
+
+ def __init__(self, min_num_spks=1, max_num_spks=15, pval=0.022):
+ self.min_num_spks = min_num_spks
+ self.max_num_spks = max_num_spks
+ self.pval = pval
+
+ def __call__(self, X, oracle_num=None):
+ # Pairwise cosine-similarity (affinity) matrix
+ sim_mat = self.get_sim_mat(X)
+
+ # Prune each row's smallest similarities (fraction controlled by pval)
+ prunned_sim_mat = self.p_pruning(sim_mat)
+
+ # Symmetrization
+ sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)
+
+ # Unnormalized graph Laplacian: L = D - M
+ laplacian = self.get_laplacian(sym_prund_sim_mat)
+
+ # Spectral embeddings plus estimated (or oracle) speaker count
+ emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)
+
+ # k-means clustering in the spectral embedding space
+ labels = self.cluster_embs(emb, num_of_spk)
+
+ return labels
+
+ def get_sim_mat(self, X):
+ # Cosine similarities, shape [N, N]
+ M = sklearn.metrics.pairwise.cosine_similarity(X, X)
+ return M
+
+ def p_pruning(self, A):
+ if A.shape[0] * self.pval < 6:
+ pval = 6.0 / A.shape[0]  # keep at least 6 entries per row
+ else:
+ pval = self.pval
+
+ n_elems = int((1 - pval) * A.shape[0])
+
+ # For each row of the affinity matrix, find its n_elems smallest entries
+ for i in range(A.shape[0]):
+ low_indexes = np.argsort(A[i, :])
+ low_indexes = low_indexes[0:n_elems]
+
+ # Replace the smaller similarity values by 0s
+ A[i, low_indexes] = 0
+ return A
+
+ def get_laplacian(self, M):
+ M[np.diag_indices(M.shape[0])] = 0  # drop self-similarity on the diagonal
+ D = np.sum(np.abs(M), axis=1)
+ D = np.diag(D)
+ L = D - M
+ return L
+
+ def get_spec_embs(self, L, k_oracle=None):
+ lambdas, eig_vecs = scipy.linalg.eigh(L)
+
+ if k_oracle is not None:
+ num_of_spk = k_oracle
+ else:
+ lambda_gap_list = self.getEigenGaps(
+ lambdas[self.min_num_spks - 1 : self.max_num_spks + 1]
+ )
+ num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks  # largest eigen-gap selects k
+
+ emb = eig_vecs[:, :num_of_spk]
+ return emb, num_of_spk
+
+ def cluster_embs(self, emb, k):
+ _, labels, _ = k_means(emb, k)
+ return labels
+
+ def getEigenGaps(self, eig_vals):
+ eig_vals_gap_list = []  # gaps between consecutive eigenvalues
+ for i in range(len(eig_vals) - 1):
+ gap = float(eig_vals[i + 1]) - float(eig_vals[i])
+ eig_vals_gap_list.append(gap)
+ return eig_vals_gap_list
+
+
+class UmapHdbscan:
+ r"""UMAP dimensionality reduction followed by HDBSCAN density clustering.
+ Reference:
+ - Siqi Zheng, Hongbin Suo. Reformulating Speaker Diarization as Community Detection With
+ Emphasis On Topological Structure. ICASSP2022
+ """
+
+ def __init__(
+ self, n_neighbors=20, n_components=60, min_samples=10, min_cluster_size=10, metric="cosine"
+ ):
+ self.n_neighbors = n_neighbors
+ self.n_components = n_components
+ self.min_samples = min_samples
+ self.min_cluster_size = min_cluster_size
+ self.metric = metric
+
+ def __call__(self, X):
+ import umap.umap_ as umap  # imported lazily: umap-learn is an optional dependency
+
+ umap_X = umap.UMAP(
+ n_neighbors=self.n_neighbors,
+ min_dist=0.0,
+ n_components=min(self.n_components, X.shape[0] - 2),  # cap components for small N
+ metric=self.metric,
+ ).fit_transform(X)
+ labels = HDBSCAN(
+ min_samples=self.min_samples,
+ min_cluster_size=self.min_cluster_size,
+ allow_single_cluster=True,
+ ).fit_predict(umap_X)
+ return labels
+
+
+class ClusterBackend(torch.nn.Module):
+ r"""Perform clustering for input embeddings and output the labels.
+ Args:
+ merge_thr: Cosine-similarity threshold above which two speaker
+ clusters are merged into one.
+ """
+
+ def __init__(self, merge_thr=0.78):
+ super().__init__()
+ self.model_config = {"merge_thr": merge_thr}
+ # self.other_config = kwargs
+
+ self.spectral_cluster = SpectralCluster()
+ self.umap_hdbscan_cluster = UmapHdbscan()
+
+ def forward(self, X, **params):
+ # Cluster the embeddings in X and return one integer label per row
+ k = params["oracle_num"] if "oracle_num" in params else None
+ assert len(X.shape) == 2, "modelscope error: the shape of input should be [N, C]"
+ if X.shape[0] < 20:
+ return np.zeros(X.shape[0], dtype="int")  # too few segments: assume a single speaker
+ if X.shape[0] < 2048 or k is not None:
+ # Spectral clustering for moderate N; also used whenever the
+ # speaker count is known (oracle_num forces this path)
+ labels = self.spectral_cluster(X, k)
+ else:
+ labels = self.umap_hdbscan_cluster(X)
+
+ if k is None and "merge_thr" in self.model_config:
+ labels = self.merge_by_cos(labels, X, self.model_config["merge_thr"])
+
+ return labels
+
+ def merge_by_cos(self, labels, embs, cos_thr):
+ # Repeatedly merge the most similar pair of speaker centroids until the best pair falls below cos_thr
+ assert cos_thr > 0 and cos_thr <= 1
+ while True:
+ spk_num = labels.max() + 1
+ if spk_num == 1:
+ break
+ spk_center = []
+ for i in range(spk_num):
+ spk_emb = embs[labels == i].mean(0)
+ spk_center.append(spk_emb)
+ assert len(spk_center) > 0
+ spk_center = np.stack(spk_center, axis=0)
+ norm_spk_center = spk_center / np.linalg.norm(spk_center, axis=1, keepdims=True)
+ affinity = np.matmul(norm_spk_center, norm_spk_center.T)
+ affinity = np.triu(affinity, 1)  # upper triangle: consider each unordered pair once
+ spks = np.unravel_index(np.argmax(affinity), affinity.shape)
+ if affinity[spks] < cos_thr:
+ break
+ for i in range(len(labels)):
+ if labels[i] == spks[1]:
+ labels[i] = spks[0]
+ elif labels[i] > spks[1]:
+ labels[i] -= 1  # keep label ids contiguous after a merge
+ return labels
diff --git a/qwen_asr/inference/qwen3_asr.py b/qwen_asr/inference/qwen3_asr.py
index d99b915..1d7ff85 100644
--- a/qwen_asr/inference/qwen3_asr.py
+++ b/qwen_asr/inference/qwen3_asr.py
@@ -557,7 +557,8 @@ def _offset_align_result(self, result: Any, offset_sec: float) -> Any:
for it in result.items:
items.append(type(it)(text=it.text,
start_time=round(it.start_time + offset_sec, 3),
- end_time=round(it.end_time + offset_sec, 3)))
+ end_time=round(it.end_time + offset_sec, 3),
+ speaker=getattr(it, 'speaker', None)))
return type(result)(items=items)
def _merge_align_results(self, results: List[Any]) -> Optional[Any]:
@@ -733,17 +734,8 @@ def streaming_transcribe(self, pcm16k: np.ndarray, state: ASRStreamingState) ->
prefix = ""
else:
cur_ids = self.processor.tokenizer.encode(state._raw_decoded)
- k = int(state.unfixed_token_num)
- while True:
- end_idx = max(0, len(cur_ids) - k)
- prefix = self.processor.tokenizer.decode(cur_ids[:end_idx]) if end_idx > 0 else ""
- if '\ufffd' not in prefix:
- break
- else:
- if end_idx == 0:
- prefix = ""
- break
- k += 1
+ end_idx = max(1, len(cur_ids) - int(state.unfixed_token_num))
+ prefix = self.processor.tokenizer.decode(cur_ids[:end_idx])
prompt = state.prompt_raw + prefix
diff --git a/qwen_asr/inference/qwen3_forced_aligner.py b/qwen_asr/inference/qwen3_forced_aligner.py
index 76fe043..3788620 100644
--- a/qwen_asr/inference/qwen3_forced_aligner.py
+++ b/qwen_asr/inference/qwen3_forced_aligner.py
@@ -14,23 +14,75 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
+import logging
+import soundfile as sf
import unicodedata
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
-
+import regex
import nagisa
+import numpy as np
import torch
+import torchaudio
+
from qwen_asr.core.transformers_backend import (
Qwen3ASRConfig,
Qwen3ASRForConditionalGeneration,
Qwen3ASRProcessor,
)
from transformers import AutoConfig, AutoModel, AutoProcessor
-
+import torchaudio.compliance.kaldi as kaldi
+from qwen_asr.inference.cluster_backend import ClusterBackend
+import onnxruntime
+
+logger = logging.getLogger(__name__)
+
+
+def _resolve_campplus_model_path(campplus_model: str) -> str:
+ if not campplus_model:
+ return None
+ if os.path.isfile(campplus_model):
+ return campplus_model  # already a local file
+ if "/" in campplus_model and not os.path.exists(campplus_model):
+ parts = campplus_model.split("/")
+ if len(parts) >= 2:
+ repo_id = "/".join(parts[:-1])  # value treated as "<repo_id>/<filename>"
+ filename = parts[-1]
+ try:
+ from modelscope.hub.file_download import model_file_download
+ cache_path = model_file_download(
+ model_id=repo_id,
+ file_path=filename,
+ )
+ logger.info(f"Downloaded campplus model from ModelScope: {campplus_model} -> {cache_path}")
+ return cache_path
+ except ImportError:
+ pass  # modelscope not installed; fall through to HuggingFace
+ except Exception as e:
+ logger.warning(f"Failed to download from ModelScope: {e}")
+ try:
+ from huggingface_hub import hf_hub_download
+ cache_path = hf_hub_download(
+ repo_id=repo_id,
+ filename=filename,
+ )
+ logger.info(f"Downloaded campplus model from HuggingFace: {campplus_model} -> {cache_path}")
+ return cache_path
+ except ImportError:
+ pass  # huggingface_hub not installed either
+ except Exception as e:
+ logger.warning(f"Failed to download from HuggingFace: {e}")
+ raise FileNotFoundError(
+ f"Cannot find campplus model at '{campplus_model}'. "
+ f"Please install modelscope (pip install modelscope) or huggingface_hub (pip install huggingface_hub), "
+ f"or provide a valid local path."
+ )
+ return campplus_model  # not a repo-style id: returned unchanged (may not exist on disk)
from .utils import (
AudioLike,
ensure_list,
normalize_audios,
+ SAMPLE_RATE,
)
@@ -279,10 +331,13 @@ class ForcedAlignItem:
Start time in seconds.
end_time (float):
End time in seconds.
+ speaker (Optional[int]):
+ Speaker label (cluster ID) for this segment.
"""
text: str
start_time: int
end_time: int
+ speaker: Optional[int] = None
@dataclass(frozen=True)
@@ -322,10 +377,20 @@ def __init__(
model: Qwen3ASRForConditionalGeneration,
processor: Qwen3ASRProcessor,
aligner_processor: Qwen3ForceAlignProcessor,
+ campplus_model: str = None,
):
self.model = model
self.processor = processor
self.aligner_processor = aligner_processor
+ # refer to https://github.com/FunAudioLLM/CosyVoice.git
+ if campplus_model:
+ resolved_path = _resolve_campplus_model_path(campplus_model)
+ self.campplus_model = resolved_path
+ self.cb_model = ClusterBackend().to('cpu')
+ option = onnxruntime.SessionOptions()
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+ option.intra_op_num_threads = 1
+ self.campplus_session = onnxruntime.InferenceSession(resolved_path, sess_options=option, providers=["CPUExecutionProvider"])
self.device = getattr(model, "device", None)
if self.device is None:
@@ -368,6 +433,9 @@ def from_pretrained(
AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)
+ # Extract campplus_model from kwargs before passing to AutoModel
+ campplus_model = kwargs.pop('campplus_model', None)
+
model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
if not isinstance(model, Qwen3ASRForConditionalGeneration):
raise TypeError(
@@ -377,7 +445,7 @@ def from_pretrained(
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True)
aligner_processor = Qwen3ForceAlignProcessor()
- return cls(model=model, processor=processor, aligner_processor=aligner_processor)
+ return cls(model=model, processor=processor, aligner_processor=aligner_processor, campplus_model=campplus_model)
def _to_structured_items(self, timestamp_output: List[Dict[str, Any]]) -> ForcedAlignResult:
items: List[ForcedAlignItem] = []
@@ -385,12 +453,26 @@ def _to_structured_items(self, timestamp_output: List[Dict[str, Any]]) -> Forced
items.append(
ForcedAlignItem(
text=str(it.get("text", "")),
- start_time=float(it.get("start_time", 0)),
- end_time=float(it.get("end_time", 0)),
+ start_time=int(it.get("start_time", 0)),
+ end_time=int(it.get("end_time", 0)),
)
)
return ForcedAlignResult(items=items)
-
+
+ @staticmethod
+ def load_wav(wav, target_sr, min_sr=16000):
+ if isinstance(wav, str):
+ speech, sample_rate = sf.read(wav)  # path input: read from disk
+ speech = torch.from_numpy(speech).float().unsqueeze(0)
+ else:
+ speech = torch.from_numpy(wav).float().unsqueeze(0)
+ sample_rate = target_sr  # ndarray input is assumed to already be at target_sr
+
+ if sample_rate != target_sr:
+ assert sample_rate >= min_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)  # NOTE(review): the check uses min_sr but the message formats target_sr -- confirm which was intended
+ speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
+ return speech
+
@torch.inference_mode()
def align(
self,
@@ -419,14 +501,14 @@ def align(
"""
texts = ensure_list(text)
languages = ensure_list(language)
- audios = normalize_audios(audio)
+ norm_audio = normalize_audios(audio)
- if len(languages) == 1 and len(audios) > 1:
- languages = languages * len(audios)
+ if len(languages) == 1 and len(norm_audio) > 1:
+ languages = languages * len(norm_audio)
- if not (len(audios) == len(texts) == len(languages)):
+ if not (len(norm_audio) == len(texts) == len(languages)):
raise ValueError(
- f"Batch size mismatch: audio={len(audios)}, text={len(texts)}, language={len(languages)}"
+ f"Batch size mismatch: audio={len(norm_audio)}, text={len(texts)}, language={len(languages)}"
)
word_lists = []
@@ -438,7 +520,7 @@ def align(
inputs = self.processor(
text=aligner_input_texts,
- audio=audios,
+ audio=norm_audio,
return_tensors="pt",
padding=True,
)
@@ -452,12 +534,149 @@ def align(
masked_output_id = output_id[input_id == self.timestamp_token_id]
timestamp_ms = (masked_output_id * self.timestamp_segment_time).to("cpu").numpy()
timestamp_output = self.aligner_processor.parse_timestamp(word_list, timestamp_ms)
- for it in timestamp_output:
- it['start_time'] = round(it['start_time'] / 1000.0, 3)
- it['end_time'] = round(it['end_time'] / 1000.0, 3)
+ if not self.campplus_model:
+ for it in timestamp_output:
+ it['start_time'] = round(it['start_time'] / 1000.0, 3)
+ it['end_time'] = round(it['end_time'] / 1000.0, 3)
results.append(self._to_structured_items(timestamp_output))
+ if not self.campplus_model:
+ return results
+
+ audio_segments, text_segments = self._extract_audio_segments(texts, results, norm_audio, languages)
+ spk_labels = self._cluster_speakers(audio_segments)
+ return self._assign_speaker_labels(texts, results, text_segments, spk_labels)
- return results
+ def _extract_audio_segments(
+ self,
+ texts: List[str],
+ results: List["ForcedAlignResult"],
+ norm_audio: List,
+ languages: List[str],
+ ) -> tuple:
+ audio_segments = []  # raw audio slices used for speaker-embedding extraction
+ text_segments = []
+ fbank_min_samples = int(25 / 1000 * SAMPLE_RATE)  # samples in one 25 ms fbank frame
+
+ for text, align_result, audio_ndarray, language in zip(texts, results, norm_audio, languages):
+ segments_by_punc = [s for s in regex.split(r'[\p{P}]', text) if s]  # split on any Unicode punctuation
+ char_offset = 0
+ max_idx = len(align_result) - 1
+
+ for seg in segments_by_punc:
+ seg_len = len(seg)
+ start_time = align_result[min(char_offset, max_idx)].start_time
+ end_time = align_result[min(char_offset + seg_len - 1, max_idx)].end_time
+
+ start_sample = int(start_time * (SAMPLE_RATE / 1000))  # align times are in milliseconds
+ end_sample = int(end_time * (SAMPLE_RATE / 1000))
+
+ if end_sample - start_sample >= fbank_min_samples:  # skip segments too short for one frame
+ audio_segments.append(audio_ndarray[start_sample:end_sample])
+ text_segments.append({
+ 'text': seg,
+ 'start': start_time,
+ 'end': end_time,
+ 'language': language,
+ 'audio_idx': len(audio_segments) - 1,
+ })
+
+ char_offset += seg_len
+
+ return audio_segments, text_segments
+
+ def _extract_embedding(self, wav: np.ndarray) -> torch.Tensor:
+ speech = self.load_wav(wav, target_sr=SAMPLE_RATE)
+ feat = kaldi.fbank(
+ speech,
+ num_mel_bins=80,
+ dither=0,
+ sample_frequency=SAMPLE_RATE,
+ )
+ feat = feat - feat.mean(dim=0, keepdim=True)  # per-utterance mean normalization over frames
+ embedding = self.campplus_session.run(  # CAM++ ONNX inference on CPU
+ None,
+ {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()},
+ )[0].flatten()
+ return torch.from_numpy(embedding)
+
+ def _cluster_speakers(self, audio_segments: List) -> np.ndarray:
+ if not audio_segments:
+ return np.array([], dtype=int)  # no usable segments to cluster
+
+ embeddings = torch.stack([self._extract_embedding(wav) for wav in audio_segments])
+ return self.cb_model(embeddings)  # ClusterBackend: one speaker label per segment
+
+ def _assign_speaker_labels(
+ self,
+ texts: List[str],
+ results: List["ForcedAlignResult"],
+ text_segments: List[dict],
+ spk_labels: np.ndarray,
+ ) -> List["ForcedAlignResult"]:
+ segment_idx = 0
+ time_offset = 0  # cumulative end time of all previous results
+ final_results = []
+
+ for align_result in results:
+ new_items = []
+ for item in align_result:
+ speaker = self._find_speaker_for_item(
+ item, text_segments, spk_labels, segment_idx, time_offset
+ )
+ new_items.append(ForcedAlignItem(
+ text=item.text,
+ start_time=item.start_time,
+ end_time=item.end_time,
+ speaker=speaker,
+ ))
+ segment_idx, _ = self._advance_segment(  # skip segments this item has already passed
+ item, text_segments, segment_idx, time_offset
+ )
+
+ if align_result:
+ time_offset += align_result[-1].end_time  # segment times span all results globally
+ final_results.append(ForcedAlignResult(items=new_items))
+
+ return final_results
+
+ def _find_speaker_for_item(
+ self,
+ item: "ForcedAlignItem",
+ text_segments: List[dict],
+ spk_labels: np.ndarray,
+ segment_idx: int,
+ time_offset: int,
+ ) -> Optional[int]:
+ while segment_idx < len(text_segments):
+ seg = text_segments[segment_idx]
+ seg_start = seg['start'] - time_offset
+ seg_end = seg['end'] - time_offset
+
+ if item.start_time >= seg_start and item.end_time <= seg_end:
+ audio_idx = seg['audio_idx']  # item lies fully inside this segment
+ return int(spk_labels[audio_idx]) if audio_idx < len(spk_labels) else None
+ elif item.end_time > seg_end:
+ segment_idx += 1  # item extends past this segment: try the next one
+ else:
+ break
+ return None  # item not covered by any clustered segment
+
+ def _advance_segment(
+ self,
+ item: "ForcedAlignItem",
+ text_segments: List[dict],
+ segment_idx: int,
+ time_offset: int,
+ ) -> tuple:
+ while segment_idx < len(text_segments):
+ seg = text_segments[segment_idx]
+ seg_end = seg['end'] - time_offset
+
+ if item.end_time > seg_end:
+ segment_idx += 1  # this segment ends before the item does: move on
+ else:
+ break
+ return segment_idx, time_offset  # time_offset is returned unchanged
def get_supported_languages(self) -> Optional[List[str]]:
"""