diff --git a/README.md b/README.md index 974607c..2211ce4 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@

-  🤗 Hugging Face   |   🤖 ModelScope   |   📑 Blog   |   📑 Paper   +  🤗 Hugging Face   |   🤖 ModelScope   |   📑 Blog   |   📑 Paper  
🖥️ Hugging Face Demo   |    🖥️ ModelScope Demo   |   💬 WeChat (微信)   |   🫨 Discord   |   📑 API @@ -41,7 +41,7 @@ We release **Qwen3-ASR**, a family that includes two powerful all-in-one speech - [Fine Tuning](#fine-tuning) - [Docker](#docker) - [Evaluation](#evaluation) -- [Citation](#citation) + ## Overview @@ -1420,18 +1420,18 @@ During evaluation, we ran inference for all models with `dtype=torch.bfloat16` a -## Citation + ## Star History diff --git a/pyproject.toml b/pyproject.toml index 97a4ae1..93ea350 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "qwen-asr" -version = "0.0.6" +version = "0.0.7" description = "Qwen-ASR python package" readme = "README.md" requires-python = ">=3.9" diff --git a/qwen_asr/cli/demo.py b/qwen_asr/cli/demo.py index 5f93ba5..70eff75 100644 --- a/qwen_asr/cli/demo.py +++ b/qwen_asr/cli/demo.py @@ -144,6 +144,12 @@ def build_parser() -> argparse.ArgumentParser: help="Qwen3-ForcedAligner checkpoint path or HF repo id (optional; enables timestamps when provided).", ) + parser.add_argument( + "--campplus-model", + default="FunAudioLLM/Fun-CosyVoice3-0.5B-2512/campplus.onnx", + help="Campplus model path for speaker diarization (optional).", + ) + parser.add_argument( "--backend", default="transformers", @@ -306,6 +312,7 @@ def _make_timestamp_html(audio_upload: Any, timestamps: Any) -> str: word = str(item.get("text", "") or "") start = item.get("start_time", None) end = item.get("end_time", None) + speaker = item.get("speaker", None) if start is None or end is None: continue @@ -328,13 +335,17 @@ def _make_timestamp_html(audio_upload: Any, timestamps: Any) -> str: b64 = base64.b64encode(mem.read()).decode("utf-8") audio_src = f"data:audio/wav;base64,{b64}" + speaker_label = f"Speaker {speaker}" if speaker is not None else "" + speaker_style = f"border-left: 4px solid hsl({(speaker % 12) * 30}, 70%, 50%);" if speaker is not None else "" + html_content += f""" -

+
{word}
-
{start} - {end} s
+
{start:.2f} - {end:.2f} s
+ {f'
{speaker_label}
' if speaker_label else ''}
""" @@ -438,6 +449,7 @@ def run(audio_upload: Any, lang_disp: str, return_ts: bool): text=getattr(t, "text", None), start_time=getattr(t, "start_time", None), end_time=getattr(t, "end_time", None), + speaker=getattr(t, "speaker", None), ) for t in (getattr(r, "time_stamps", None) or []) ] @@ -490,6 +502,10 @@ def main(argv=None) -> int: user_backend_kwargs = _parse_json_dict(args.backend_kwargs, name="--backend-kwargs") user_aligner_kwargs = _parse_json_dict(args.aligner_kwargs, name="--aligner-kwargs") + # Add campplus model path to aligner kwargs + if hasattr(args, "campplus_model") and args.campplus_model: + user_aligner_kwargs["campplus_model"] = args.campplus_model + backend_kwargs = _merge_dicts(_default_backend_kwargs(backend), user_backend_kwargs) backend_kwargs = _coerce_special_types(backend_kwargs) diff --git a/qwen_asr/inference/cluster_backend.py b/qwen_asr/inference/cluster_backend.py new file mode 100644 index 0000000..80479ad --- /dev/null +++ b/qwen_asr/inference/cluster_backend.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) +# Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker) + +import scipy +import torch +import sklearn +import numpy as np + +from sklearn.cluster._kmeans import k_means +from sklearn.cluster import HDBSCAN + + +class SpectralCluster: + r"""A spectral clustering mehtod using unnormalized Laplacian of affinity matrix. + This implementation is adapted from https://github.com/speechbrain/speechbrain. 
+ """ + + def __init__(self, min_num_spks=1, max_num_spks=15, pval=0.022): + self.min_num_spks = min_num_spks + self.max_num_spks = max_num_spks + self.pval = pval + + def __call__(self, X, oracle_num=None): + # Similarity matrix computation + sim_mat = self.get_sim_mat(X) + + # Refining similarity matrix with pval + prunned_sim_mat = self.p_pruning(sim_mat) + + # Symmetrization + sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T) + + # Laplacian calculation + laplacian = self.get_laplacian(sym_prund_sim_mat) + + # Get Spectral Embeddings + emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num) + + # Perform clustering + labels = self.cluster_embs(emb, num_of_spk) + + return labels + + def get_sim_mat(self, X): + # Cosine similarities + M = sklearn.metrics.pairwise.cosine_similarity(X, X) + return M + + def p_pruning(self, A): + if A.shape[0] * self.pval < 6: + pval = 6.0 / A.shape[0] + else: + pval = self.pval + + n_elems = int((1 - pval) * A.shape[0]) + + # For each row in a affinity matrix + for i in range(A.shape[0]): + low_indexes = np.argsort(A[i, :]) + low_indexes = low_indexes[0:n_elems] + + # Replace smaller similarity values by 0s + A[i, low_indexes] = 0 + return A + + def get_laplacian(self, M): + M[np.diag_indices(M.shape[0])] = 0 + D = np.sum(np.abs(M), axis=1) + D = np.diag(D) + L = D - M + return L + + def get_spec_embs(self, L, k_oracle=None): + lambdas, eig_vecs = scipy.linalg.eigh(L) + + if k_oracle is not None: + num_of_spk = k_oracle + else: + lambda_gap_list = self.getEigenGaps( + lambdas[self.min_num_spks - 1 : self.max_num_spks + 1] + ) + num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks + + emb = eig_vecs[:, :num_of_spk] + return emb, num_of_spk + + def cluster_embs(self, emb, k): + _, labels, _ = k_means(emb, k) + return labels + + def getEigenGaps(self, eig_vals): + eig_vals_gap_list = [] + for i in range(len(eig_vals) - 1): + gap = float(eig_vals[i + 1]) - float(eig_vals[i]) + eig_vals_gap_list.append(gap) + 
return eig_vals_gap_list + + +class UmapHdbscan: + r""" + Reference: + - Siqi Zheng, Hongbin Suo. Reformulating Speaker Diarization as Community Detection With + Emphasis On Topological Structure. ICASSP2022 + """ + + def __init__( + self, n_neighbors=20, n_components=60, min_samples=10, min_cluster_size=10, metric="cosine" + ): + self.n_neighbors = n_neighbors + self.n_components = n_components + self.min_samples = min_samples + self.min_cluster_size = min_cluster_size + self.metric = metric + + def __call__(self, X): + import umap.umap_ as umap + + umap_X = umap.UMAP( + n_neighbors=self.n_neighbors, + min_dist=0.0, + n_components=min(self.n_components, X.shape[0] - 2), + metric=self.metric, + ).fit_transform(X) + labels = HDBSCAN( + min_samples=self.min_samples, + min_cluster_size=self.min_cluster_size, + allow_single_cluster=True, + ).fit_predict(umap_X) + return labels + + +class ClusterBackend(torch.nn.Module): + r"""Perfom clustering for input embeddings and output the labels. + Args: + model_dir: A model dir. + model_config: The model config. 
+ """ + + def __init__(self, merge_thr=0.78): + super().__init__() + self.model_config = {"merge_thr": merge_thr} + # self.other_config = kwargs + + self.spectral_cluster = SpectralCluster() + self.umap_hdbscan_cluster = UmapHdbscan() + + def forward(self, X, **params): + # clustering and return the labels + k = params["oracle_num"] if "oracle_num" in params else None + assert len(X.shape) == 2, "modelscope error: the shape of input should be [N, C]" + if X.shape[0] < 20: + return np.zeros(X.shape[0], dtype="int") + if X.shape[0] < 2048 or k is not None: + # default + # unexpected corner case + labels = self.spectral_cluster(X, k) + else: + labels = self.umap_hdbscan_cluster(X) + + if k is None and "merge_thr" in self.model_config: + labels = self.merge_by_cos(labels, X, self.model_config["merge_thr"]) + + return labels + + def merge_by_cos(self, labels, embs, cos_thr): + # merge the similar speakers by cosine similarity + assert cos_thr > 0 and cos_thr <= 1 + while True: + spk_num = labels.max() + 1 + if spk_num == 1: + break + spk_center = [] + for i in range(spk_num): + spk_emb = embs[labels == i].mean(0) + spk_center.append(spk_emb) + assert len(spk_center) > 0 + spk_center = np.stack(spk_center, axis=0) + norm_spk_center = spk_center / np.linalg.norm(spk_center, axis=1, keepdims=True) + affinity = np.matmul(norm_spk_center, norm_spk_center.T) + affinity = np.triu(affinity, 1) + spks = np.unravel_index(np.argmax(affinity), affinity.shape) + if affinity[spks] < cos_thr: + break + for i in range(len(labels)): + if labels[i] == spks[1]: + labels[i] = spks[0] + elif labels[i] > spks[1]: + labels[i] -= 1 + return labels diff --git a/qwen_asr/inference/qwen3_asr.py b/qwen_asr/inference/qwen3_asr.py index d99b915..1d7ff85 100644 --- a/qwen_asr/inference/qwen3_asr.py +++ b/qwen_asr/inference/qwen3_asr.py @@ -557,7 +557,8 @@ def _offset_align_result(self, result: Any, offset_sec: float) -> Any: for it in result.items: items.append(type(it)(text=it.text, 
start_time=round(it.start_time + offset_sec, 3), - end_time=round(it.end_time + offset_sec, 3))) + end_time=round(it.end_time + offset_sec, 3), + speaker=getattr(it, 'speaker', None))) return type(result)(items=items) def _merge_align_results(self, results: List[Any]) -> Optional[Any]: @@ -733,17 +734,8 @@ def streaming_transcribe(self, pcm16k: np.ndarray, state: ASRStreamingState) -> prefix = "" else: cur_ids = self.processor.tokenizer.encode(state._raw_decoded) - k = int(state.unfixed_token_num) - while True: - end_idx = max(0, len(cur_ids) - k) - prefix = self.processor.tokenizer.decode(cur_ids[:end_idx]) if end_idx > 0 else "" - if '\ufffd' not in prefix: - break - else: - if end_idx == 0: - prefix = "" - break - k += 1 + end_idx = max(0, len(cur_ids) - int(state.unfixed_token_num)) + prefix = self.processor.tokenizer.decode(cur_ids[:end_idx]).rstrip('\ufffd') prompt = state.prompt_raw + prefix diff --git a/qwen_asr/inference/qwen3_forced_aligner.py b/qwen_asr/inference/qwen3_forced_aligner.py index 76fe043..3788620 100644 --- a/qwen_asr/inference/qwen3_forced_aligner.py +++ b/qwen_asr/inference/qwen3_forced_aligner.py @@ -14,23 +14,75 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +import logging +import soundfile as sf import unicodedata from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union - +import regex import nagisa +import numpy as np import torch +import torchaudio + from qwen_asr.core.transformers_backend import ( Qwen3ASRConfig, Qwen3ASRForConditionalGeneration, Qwen3ASRProcessor, ) from transformers import AutoConfig, AutoModel, AutoProcessor - +import torchaudio.compliance.kaldi as kaldi +from qwen_asr.inference.cluster_backend import ClusterBackend +import onnxruntime + +logger = logging.getLogger(__name__) + + +def _resolve_campplus_model_path(campplus_model: str) -> str: + if not campplus_model: + return None + if os.path.isfile(campplus_model): + return campplus_model + if "/" in campplus_model and not os.path.exists(campplus_model): + parts = campplus_model.split("/") + if len(parts) >= 2: + repo_id = "/".join(parts[:-1]) + filename = parts[-1] + try: + from modelscope.hub.file_download import model_file_download + cache_path = model_file_download( + model_id=repo_id, + file_path=filename, + ) + logger.info(f"Downloaded campplus model from ModelScope: {campplus_model} -> {cache_path}") + return cache_path + except ImportError: + pass + except Exception as e: + logger.warning(f"Failed to download from ModelScope: {e}") + try: + from huggingface_hub import hf_hub_download + cache_path = hf_hub_download( + repo_id=repo_id, + filename=filename, + ) + logger.info(f"Downloaded campplus model from HuggingFace: {campplus_model} -> {cache_path}") + return cache_path + except ImportError: + pass + except Exception as e: + logger.warning(f"Failed to download from HuggingFace: {e}") + raise FileNotFoundError( + f"Cannot find campplus model at '{campplus_model}'. " + f"Please install modelscope (pip install modelscope) or huggingface_hub (pip install huggingface_hub), " + f"or provide a valid local path." 
+ ) + return campplus_model from .utils import ( AudioLike, ensure_list, normalize_audios, + SAMPLE_RATE, ) @@ -279,10 +331,13 @@ class ForcedAlignItem: Start time in seconds. end_time (float): End time in seconds. + speaker (Optional[int]): + Speaker label (cluster ID) for this segment. """ text: str start_time: int end_time: int + speaker: Optional[int] = None @dataclass(frozen=True) @@ -322,10 +377,22 @@ def __init__( model: Qwen3ASRForConditionalGeneration, processor: Qwen3ASRProcessor, aligner_processor: Qwen3ForceAlignProcessor, + campplus_model: str = None, ): self.model = model self.processor = processor self.aligner_processor = aligner_processor + # refer to https://github.com/FunAudioLLM/CosyVoice.git + if campplus_model: + resolved_path = _resolve_campplus_model_path(campplus_model) + self.campplus_model = resolved_path + self.cb_model = ClusterBackend().to('cpu') + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + self.campplus_session = onnxruntime.InferenceSession(resolved_path, sess_options=option, providers=["CPUExecutionProvider"]) + else: + self.campplus_model = None self.device = getattr(model, "device", None) if self.device is None: @@ -368,6 +433,9 @@ def from_pretrained( AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration) AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor) + # Extract campplus_model from kwargs before passing to AutoModel + campplus_model = kwargs.pop('campplus_model', None) + model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs) if not isinstance(model, Qwen3ASRForConditionalGeneration): raise TypeError( @@ -377,7 +445,7 @@ def from_pretrained( processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True) aligner_processor = Qwen3ForceAlignProcessor() - return cls(model=model, processor=processor, aligner_processor=aligner_processor) + return cls(model=model, 
processor=processor, aligner_processor=aligner_processor, campplus_model=campplus_model) def _to_structured_items(self, timestamp_output: List[Dict[str, Any]]) -> ForcedAlignResult: items: List[ForcedAlignItem] = [] @@ -385,12 +453,26 @@ def _to_structured_items(self, timestamp_output: List[Dict[str, Any]]) -> Forced items.append( ForcedAlignItem( text=str(it.get("text", "")), - start_time=float(it.get("start_time", 0)), - end_time=float(it.get("end_time", 0)), + start_time=float(it.get("start_time", 0)), + end_time=float(it.get("end_time", 0)), ) ) return ForcedAlignResult(items=items) - + + @staticmethod + def load_wav(wav, target_sr, min_sr=16000): + if isinstance(wav, str): + speech, sample_rate = sf.read(wav) + speech = torch.from_numpy(speech).float().unsqueeze(0) + else: + speech = torch.from_numpy(wav).float().unsqueeze(0) + sample_rate = target_sr + + if sample_rate != target_sr: + assert sample_rate >= min_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, min_sr) + speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech) + return speech + @torch.inference_mode() def align( self, @@ -419,14 +501,14 @@ def align( """ texts = ensure_list(text) languages = ensure_list(language) - audios = normalize_audios(audio) + norm_audio = normalize_audios(audio) - if len(languages) == 1 and len(audios) > 1: - languages = languages * len(audios) + if len(languages) == 1 and len(norm_audio) > 1: + languages = languages * len(norm_audio) - if not (len(audios) == len(texts) == len(languages)): + if not (len(norm_audio) == len(texts) == len(languages)): raise ValueError( - f"Batch size mismatch: audio={len(audios)}, text={len(texts)}, language={len(languages)}" + f"Batch size mismatch: audio={len(norm_audio)}, text={len(texts)}, language={len(languages)}" ) word_lists = [] @@ -438,7 +520,7 @@ def align( inputs = self.processor( text=aligner_input_texts, - audio=audios, + audio=norm_audio, return_tensors="pt", padding=True, ) 
@@ -452,12 +534,149 @@ def align( masked_output_id = output_id[input_id == self.timestamp_token_id] timestamp_ms = (masked_output_id * self.timestamp_segment_time).to("cpu").numpy() timestamp_output = self.aligner_processor.parse_timestamp(word_list, timestamp_ms) - for it in timestamp_output: - it['start_time'] = round(it['start_time'] / 1000.0, 3) - it['end_time'] = round(it['end_time'] / 1000.0, 3) + if not self.campplus_model: + for it in timestamp_output: + it['start_time'] = round(it['start_time'] / 1000.0, 3) + it['end_time'] = round(it['end_time'] / 1000.0, 3) results.append(self._to_structured_items(timestamp_output)) + if not self.campplus_model: + return results + + audio_segments, text_segments = self._extract_audio_segments(texts, results, norm_audio, languages) + spk_labels = self._cluster_speakers(audio_segments) + return self._assign_speaker_labels(texts, results, text_segments, spk_labels) - return results + def _extract_audio_segments( + self, + texts: List[str], + results: List["ForcedAlignResult"], + norm_audio: List, + languages: List[str], + ) -> tuple: + audio_segments = [] + text_segments = [] + fbank_min_samples = int(25 / 1000 * SAMPLE_RATE) + + for text, align_result, audio_ndarray, language in zip(texts, results, norm_audio, languages): + segments_by_punc = [s for s in regex.split(r'[\p{P}]', text) if s] + char_offset = 0 + max_idx = len(align_result) - 1 + + for seg in segments_by_punc: + seg_len = len(seg) + start_time = align_result[min(char_offset, max_idx)].start_time + end_time = align_result[min(char_offset + seg_len - 1, max_idx)].end_time + + start_sample = int(start_time * (SAMPLE_RATE / 1000)) + end_sample = int(end_time * (SAMPLE_RATE / 1000)) + + if end_sample - start_sample >= fbank_min_samples: + audio_segments.append(audio_ndarray[start_sample:end_sample]) + text_segments.append({ + 'text': seg, + 'start': start_time, + 'end': end_time, + 'language': language, + 'audio_idx': len(audio_segments) - 1, + }) + + char_offset 
+= seg_len + + return audio_segments, text_segments + + def _extract_embedding(self, wav: np.ndarray) -> torch.Tensor: + speech = self.load_wav(wav, target_sr=SAMPLE_RATE) + feat = kaldi.fbank( + speech, + num_mel_bins=80, + dither=0, + sample_frequency=SAMPLE_RATE, + ) + feat = feat - feat.mean(dim=0, keepdim=True) + embedding = self.campplus_session.run( + None, + {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()}, + )[0].flatten() + return torch.from_numpy(embedding) + + def _cluster_speakers(self, audio_segments: List) -> np.ndarray: + if not audio_segments: + return np.array([], dtype=int) + + embeddings = torch.stack([self._extract_embedding(wav) for wav in audio_segments]) + return self.cb_model(embeddings) + + def _assign_speaker_labels( + self, + texts: List[str], + results: List["ForcedAlignResult"], + text_segments: List[dict], + spk_labels: np.ndarray, + ) -> List["ForcedAlignResult"]: + segment_idx = 0 + time_offset = 0 + final_results = [] + + for align_result in results: + new_items = [] + for item in align_result: + speaker = self._find_speaker_for_item( + item, text_segments, spk_labels, segment_idx, time_offset + ) + new_items.append(ForcedAlignItem( + text=item.text, + start_time=item.start_time, + end_time=item.end_time, + speaker=speaker, + )) + segment_idx, _ = self._advance_segment( + item, text_segments, segment_idx, time_offset + ) + + if align_result: + time_offset += align_result[-1].end_time + final_results.append(ForcedAlignResult(items=new_items)) + + return final_results + + def _find_speaker_for_item( + self, + item: "ForcedAlignItem", + text_segments: List[dict], + spk_labels: np.ndarray, + segment_idx: int, + time_offset: int, + ) -> Optional[int]: + while segment_idx < len(text_segments): + seg = text_segments[segment_idx] + seg_start = seg['start'] - time_offset + seg_end = seg['end'] - time_offset + + if item.start_time >= seg_start and item.end_time <= seg_end: + audio_idx = seg['audio_idx'] + 
return int(spk_labels[audio_idx]) if audio_idx < len(spk_labels) else None + elif item.end_time > seg_end: + segment_idx += 1 + else: + break + return None + + def _advance_segment( + self, + item: "ForcedAlignItem", + text_segments: List[dict], + segment_idx: int, + time_offset: int, + ) -> tuple: + while segment_idx < len(text_segments): + seg = text_segments[segment_idx] + seg_end = seg['end'] - time_offset + + if item.end_time > seg_end: + segment_idx += 1 + else: + break + return segment_idx, time_offset def get_supported_languages(self) -> Optional[List[str]]: """