Skip to content

Commit 4c7b135

Browse files
committed
Merge branch 'concedo_experimental' into croco_nex_0
2 parents bf9aecf + 62e33d0 commit 4c7b135

File tree

21 files changed

+1469
-308
lines changed

21 files changed

+1469
-308
lines changed

CMakeLists.txt

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,9 @@ add_library(common2
840840
examples/llava/clip.h
841841
src/unicode.h
842842
src/unicode.cpp
843-
src/unicode-data.cpp)
843+
src/unicode-data.cpp
844+
otherarch/utils.cpp
845+
otherarch/utils.h)
844846
target_include_directories(common2 PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
845847
target_compile_features(common2 PUBLIC cxx_std_17) # don't bump
846848
target_link_libraries(common2 PRIVATE ggml ${LLAMA_EXTRA_LIBS})
@@ -860,11 +862,18 @@ target_compile_features(whisper_adapter PUBLIC cxx_std_17) # don't bump
860862
target_link_libraries(whisper_adapter PRIVATE common2 ggml ${LLAMA_EXTRA_LIBS})
861863
set_target_properties(whisper_adapter PROPERTIES POSITION_INDEPENDENT_CODE ON)
862864

865+
# tts_adapter: library target compiled from otherarch/tts_adapter.cpp.
# Include directories and the C++17 requirement are PUBLIC, so they propagate
# to targets that link tts_adapter; its own link dependencies are PRIVATE.
add_library(tts_adapter
    otherarch/tts_adapter.cpp)
target_include_directories(tts_adapter PUBLIC
    .
    ./ggml/include
    ./ggml/src
    ./ggml/src/ggml-cpu
    ./include
    ./otherarch
    ./otherarch/tools
    ./examples
    ./common)
target_compile_features(tts_adapter PUBLIC cxx_std_17) # don't bump
target_link_libraries(tts_adapter PRIVATE
    common2
    ggml
    ${LLAMA_EXTRA_LIBS})
set_target_properties(tts_adapter PROPERTIES POSITION_INDEPENDENT_CODE ON)
871+
863872
add_library(gpttype_adapter
864873
gpttype_adapter.cpp)
865874
target_include_directories(gpttype_adapter PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
866875
target_compile_features(gpttype_adapter PUBLIC cxx_std_17) # don't bump
867-
target_link_libraries(gpttype_adapter PRIVATE common2 ggml ${LLAMA_EXTRA_LIBS})
876+
target_link_libraries(gpttype_adapter PRIVATE common2 ggml ggml_v1 ggml_v2 ggml_v3 ${LLAMA_EXTRA_LIBS})
868877
set_target_properties(gpttype_adapter PROPERTIES POSITION_INDEPENDENT_CODE ON)
869878

870879
if (LLAMA_CUBLAS)
@@ -875,8 +884,16 @@ if (LLAMA_CUBLAS)
875884
set_target_properties(${TARGET} PROPERTIES PREFIX "")
876885
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME "koboldcpp_cublas")
877886
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
878-
target_link_libraries(${TARGET} PUBLIC Threads::Threads ggml ggml_v1 ggml_v2 ggml_v3 common2 gpttype_adapter whisper_adapter sdtype_adapter ${LLAMA_EXTRA_LIBS})
887+
target_link_libraries(${TARGET} PUBLIC Threads::Threads ggml ggml_v1 ggml_v2 ggml_v3 common2 gpttype_adapter whisper_adapter tts_adapter sdtype_adapter ${LLAMA_EXTRA_LIBS})
879888
target_compile_features(${TARGET} PRIVATE cxx_std_17)
889+
890+
# After each build of the koboldcpp_cublas target, copy the built library file
# into the top-level source directory (CMAKE_SOURCE_DIR).
# VERBATIM makes CMake escape the command arguments identically on every
# platform/generator; the paths are quoted so a directory containing spaces or
# semicolons is still passed as a single argument.
add_custom_command(
    TARGET koboldcpp_cublas POST_BUILD
    COMMAND ${CMAKE_COMMAND} -E copy
        "$<TARGET_FILE:koboldcpp_cublas>"   # The generated DLL
        "${CMAKE_SOURCE_DIR}/"              # Destination directory
    COMMENT "Copying DLL to parent directory"
    VERBATIM
)
880897
endif()
881898

882899
if (LLAMA_HIPBLAS)
@@ -887,7 +904,15 @@ if (LLAMA_HIPBLAS)
887904
set_target_properties(${TARGET} PROPERTIES PREFIX "")
888905
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME "koboldcpp_hipblas")
889906
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
890-
target_link_libraries(${TARGET} PUBLIC Threads::Threads ggml ggml_v1 ggml_v2 ggml_v3 common2 gpttype_adapter whisper_adapter sdtype_adapter ${LLAMA_EXTRA_LIBS})
907+
target_link_libraries(${TARGET} PUBLIC Threads::Threads ggml ggml_v1 ggml_v2 ggml_v3 common2 gpttype_adapter whisper_adapter tts_adapter sdtype_adapter ${LLAMA_EXTRA_LIBS})
891908
target_compile_features(${TARGET} PRIVATE cxx_std_17)
909+
910+
# After each build of the koboldcpp_hipblas target, copy the built library file
# into the top-level source directory (CMAKE_SOURCE_DIR).
# VERBATIM makes CMake escape the command arguments identically on every
# platform/generator; the paths are quoted so a directory containing spaces or
# semicolons is still passed as a single argument.
add_custom_command(
    TARGET koboldcpp_hipblas POST_BUILD
    COMMAND ${CMAKE_COMMAND} -E copy
        "$<TARGET_FILE:koboldcpp_hipblas>"  # The generated DLL
        "${CMAKE_SOURCE_DIR}/"              # Destination directory
    COMMENT "Copying DLL to parent directory"
    VERBATIM
)
892917
endif()
893918

Makefile

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
.PHONY: finishedmsg
55

66
default: koboldcpp_default koboldcpp_failsafe koboldcpp_noavx2 koboldcpp_clblast koboldcpp_clblast_noavx2 koboldcpp_cublas koboldcpp_hipblas koboldcpp_vulkan koboldcpp_vulkan_noavx2 finishedmsg
7-
tools: quantize_gpt2 quantize_gptj quantize_gguf quantize_neox quantize_mpt quantize_clip whispermain sdmain gguf-split
7+
tools: quantize_gpt2 quantize_gptj quantize_gguf quantize_neox quantize_mpt quantize_clip ttsmain whispermain sdmain gguf-split
88

99
ifndef UNAME_S
1010
UNAME_S := $(shell uname -s)
@@ -96,10 +96,10 @@ endif
9696
CUBLASLD_FLAGS =
9797
CUBLAS_OBJS =
9898

99-
OBJS_FULL += ggml-alloc.o ggml-cpu-traits.o ggml-quants.o ggml-cpu-quants.o ggml-cpu-aarch64.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm.o common.o sampling.o
100-
OBJS_SIMPLE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_noavx2.o ggml-cpu-quants_noavx2.o ggml-cpu-aarch64_noavx2.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_noavx2.o common.o sampling.o
101-
OBJS_SIMPLER += ggml-alloc.o ggml-cpu-traits.o ggml-quants_noavx1.o ggml-cpu-quants_noavx1.o ggml-cpu-aarch64_noavx1.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_noavx1.o common.o sampling.o
102-
OBJS_FAILSAFE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_failsafe.o ggml-cpu-quants_failsafe.o ggml-cpu-aarch64_failsafe.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_failsafe.o common.o sampling.o
99+
OBJS_FULL += ggml-alloc.o ggml-cpu-traits.o ggml-quants.o ggml-cpu-quants.o ggml-cpu-aarch64.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm.o common.o sampling.o kcpputils.o
100+
OBJS_SIMPLE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_noavx2.o ggml-cpu-quants_noavx2.o ggml-cpu-aarch64_noavx2.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_noavx2.o common.o sampling.o kcpputils.o
101+
OBJS_SIMPLER += ggml-alloc.o ggml-cpu-traits.o ggml-quants_noavx1.o ggml-cpu-quants_noavx1.o ggml-cpu-aarch64_noavx1.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_noavx1.o common.o sampling.o kcpputils.o
102+
OBJS_FAILSAFE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_failsafe.o ggml-cpu-quants_failsafe.o ggml-cpu-aarch64_failsafe.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_failsafe.o common.o sampling.o kcpputils.o
103103

104104
# OS specific
105105
ifeq ($(UNAME_S),Linux)
@@ -589,6 +589,8 @@ ggml-cpu-cpp.o: ggml/src/ggml-cpu/ggml-cpu.cpp ggml/include/ggml.h ggml/src/ggml
589589
$(CXX) $(CXXFLAGS) -c $< -o $@
590590
gguf.o: ggml/src/gguf.cpp ggml/include/gguf.h
591591
$(CXX) $(CXXFLAGS) -c $< -o $@
592+
kcpputils.o: otherarch/utils.cpp otherarch/utils.h
593+
$(CXX) $(CXXFLAGS) -c $< -o $@
592594

593595
#these have special gpu defines
594596
ggml-backend_default.o: ggml/src/ggml-backend.cpp ggml/src/ggml-backend-impl.h ggml/include/ggml.h ggml/include/ggml-backend.h
@@ -689,8 +691,12 @@ whispercpp_default.o: otherarch/whispercpp/whisper_adapter.cpp
689691
whispercpp_cublas.o: otherarch/whispercpp/whisper_adapter.cpp
690692
$(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
691693

694+
#tts objects
695+
tts_default.o: otherarch/tts_adapter.cpp
696+
$(CXX) $(CXXFLAGS) -c $< -o $@
697+
692698
# idiotic "for easier compilation"
693-
GPTTYPE_ADAPTER = gpttype_adapter.cpp otherarch/llama_v2.cpp otherarch/llama_v3.cpp src/llama.cpp src/llama-impl.cpp src/llama-chat.cpp src/llama-mmap.cpp src/llama-context.cpp src/llama-adapter.cpp src/llama-arch.cpp src/llama-batch.cpp src/llama-vocab.cpp src/llama-grammar.cpp src/llama-sampling.cpp src/llama-kv-cache.cpp src/llama-model-loader.cpp src/llama-model.cpp src/llama-quant.cpp src/llama-hparams.cpp otherarch/utils.cpp otherarch/gptj_v1.cpp otherarch/gptj_v2.cpp otherarch/gptj_v3.cpp otherarch/gpt2_v1.cpp otherarch/gpt2_v2.cpp otherarch/gpt2_v3.cpp otherarch/rwkv_v2.cpp otherarch/rwkv_v3.cpp otherarch/neox_v2.cpp otherarch/neox_v3.cpp otherarch/mpt_v3.cpp ggml/include/ggml.h ggml/include/ggml-cpu.h ggml/include/ggml-cuda.h include/llama.h otherarch/llama-util.h
699+
GPTTYPE_ADAPTER = gpttype_adapter.cpp otherarch/llama_v2.cpp otherarch/llama_v3.cpp src/llama.cpp src/llama-impl.cpp src/llama-chat.cpp src/llama-mmap.cpp src/llama-context.cpp src/llama-adapter.cpp src/llama-arch.cpp src/llama-batch.cpp src/llama-vocab.cpp src/llama-grammar.cpp src/llama-sampling.cpp src/llama-kv-cache.cpp src/llama-model-loader.cpp src/llama-model.cpp src/llama-quant.cpp src/llama-hparams.cpp otherarch/gptj_v1.cpp otherarch/gptj_v2.cpp otherarch/gptj_v3.cpp otherarch/gpt2_v1.cpp otherarch/gpt2_v2.cpp otherarch/gpt2_v3.cpp otherarch/rwkv_v2.cpp otherarch/rwkv_v3.cpp otherarch/neox_v2.cpp otherarch/neox_v3.cpp otherarch/mpt_v3.cpp ggml/include/ggml.h ggml/include/ggml-cpu.h ggml/include/ggml-cuda.h include/llama.h otherarch/llama-util.h
694700
gpttype_adapter_failsafe.o: $(GPTTYPE_ADAPTER)
695701
$(CXX) $(CXXFLAGS) $(FAILSAFE_FLAGS) -c $< -o $@
696702
gpttype_adapter.o: $(GPTTYPE_ADAPTER)
@@ -735,30 +741,30 @@ vulkan-shaders-gen: ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
735741
$(shell) vulkan-shaders-gen --glslc glslc --input-dir ggml/src/ggml-vulkan/vulkan-shaders --target-hpp ggml/src/ggml-vulkan-shaders.hpp --target-cpp ggml/src/ggml-vulkan-shaders.cpp
736742

737743
#generated libraries
738-
koboldcpp_default: ggml.o ggml-cpu.o ggml_v3.o ggml_v2.o ggml_v1.o expose.o gpttype_adapter.o sdcpp_default.o whispercpp_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
744+
koboldcpp_default: ggml.o ggml-cpu.o ggml_v3.o ggml_v2.o ggml_v1.o expose.o gpttype_adapter.o sdcpp_default.o whispercpp_default.o tts_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
739745
$(DEFAULT_BUILD)
740746

741747
ifdef FAILSAFE_BUILD
742-
koboldcpp_failsafe: ggml_v4_failsafe.o ggml-cpu_v4_failsafe.o ggml_v3_failsafe.o ggml_v2_failsafe.o ggml_v1_failsafe.o expose.o gpttype_adapter_failsafe.o sdcpp_default.o whispercpp_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FAILSAFE) $(OBJS)
748+
koboldcpp_failsafe: ggml_v4_failsafe.o ggml-cpu_v4_failsafe.o ggml_v3_failsafe.o ggml_v2_failsafe.o ggml_v1_failsafe.o expose.o gpttype_adapter_failsafe.o sdcpp_default.o whispercpp_default.o tts_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FAILSAFE) $(OBJS)
743749
$(FAILSAFE_BUILD)
744750
else
745751
koboldcpp_failsafe:
746752
$(DONOTHING)
747753
endif
748754

749755
ifdef NOAVX2_BUILD
750-
koboldcpp_noavx2: ggml_v4_noavx2.o ggml-cpu_v4_noavx2.o ggml_v3_noavx2.o ggml_v2_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_failsafe.o sdcpp_default.o whispercpp_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_SIMPLE) $(OBJS)
756+
koboldcpp_noavx2: ggml_v4_noavx2.o ggml-cpu_v4_noavx2.o ggml_v3_noavx2.o ggml_v2_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_failsafe.o sdcpp_default.o whispercpp_default.o tts_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_SIMPLE) $(OBJS)
751757
$(NOAVX2_BUILD)
752758
else
753759
koboldcpp_noavx2:
754760
$(DONOTHING)
755761
endif
756762

757763
ifdef CLBLAST_BUILD
758-
koboldcpp_clblast: ggml_v4_clblast.o ggml-cpu_v4_clblast.o ggml_v3_clblast.o ggml_v2_clblast.o ggml_v1.o expose.o gpttype_adapter_clblast.o ggml-opencl.o ggml_v3-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o sdcpp_default.o whispercpp_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
764+
koboldcpp_clblast: ggml_v4_clblast.o ggml-cpu_v4_clblast.o ggml_v3_clblast.o ggml_v2_clblast.o ggml_v1.o expose.o gpttype_adapter_clblast.o ggml-opencl.o ggml_v3-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o sdcpp_default.o whispercpp_default.o tts_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
759765
$(CLBLAST_BUILD)
760766
ifdef NOAVX2_BUILD
761-
koboldcpp_clblast_noavx2: ggml_v4_clblast_noavx2.o ggml-cpu_v4_clblast_noavx2.o ggml_v3_clblast_noavx2.o ggml_v2_clblast_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_clblast_noavx2.o ggml-opencl.o ggml_v3-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o sdcpp_default.o whispercpp_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_SIMPLER) $(OBJS)
767+
koboldcpp_clblast_noavx2: ggml_v4_clblast_noavx2.o ggml-cpu_v4_clblast_noavx2.o ggml_v3_clblast_noavx2.o ggml_v2_clblast_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_clblast_noavx2.o ggml-opencl.o ggml_v3-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o sdcpp_default.o whispercpp_default.o tts_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_SIMPLER) $(OBJS)
762768
$(CLBLAST_BUILD)
763769
else
764770
koboldcpp_clblast_noavx2:
@@ -772,26 +778,26 @@ koboldcpp_clblast_noavx2:
772778
endif
773779

774780
ifdef CUBLAS_BUILD
775-
koboldcpp_cublas: ggml_v4_cublas.o ggml-cpu.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o gpttype_adapter_cublas.o sdcpp_cublas.o whispercpp_cublas.o llavaclip_cublas.o llava.o ggml-backend_cublas.o ggml-backend-reg_cublas.o $(CUBLAS_OBJS) $(OBJS_FULL) $(OBJS)
781+
koboldcpp_cublas: ggml_v4_cublas.o ggml-cpu.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o gpttype_adapter_cublas.o sdcpp_cublas.o whispercpp_cublas.o tts_default.o llavaclip_cublas.o llava.o ggml-backend_cublas.o ggml-backend-reg_cublas.o $(CUBLAS_OBJS) $(OBJS_FULL) $(OBJS)
776782
$(CUBLAS_BUILD)
777783
else
778784
koboldcpp_cublas:
779785
$(DONOTHING)
780786
endif
781787

782788
ifdef HIPBLAS_BUILD
783-
koboldcpp_hipblas: ggml_v4_cublas.o ggml-cpu.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o gpttype_adapter_cublas.o sdcpp_cublas.o whispercpp_cublas.o llavaclip_cublas.o llava.o ggml-backend_cublas.o ggml-backend-reg_cublas.o $(HIP_OBJS) $(OBJS_FULL) $(OBJS)
789+
koboldcpp_hipblas: ggml_v4_cublas.o ggml-cpu.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o gpttype_adapter_cublas.o sdcpp_cublas.o whispercpp_cublas.o tts_default.o llavaclip_cublas.o llava.o ggml-backend_cublas.o ggml-backend-reg_cublas.o $(HIP_OBJS) $(OBJS_FULL) $(OBJS)
784790
$(HIPBLAS_BUILD)
785791
else
786792
koboldcpp_hipblas:
787793
$(DONOTHING)
788794
endif
789795

790796
ifdef VULKAN_BUILD
791-
koboldcpp_vulkan: ggml_v4_vulkan.o ggml-cpu.o ggml_v3.o ggml_v2.o ggml_v1.o expose.o gpttype_adapter_vulkan.o ggml-vulkan.o sdcpp_vulkan.o whispercpp_default.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o $(OBJS_FULL) $(OBJS)
797+
koboldcpp_vulkan: ggml_v4_vulkan.o ggml-cpu.o ggml_v3.o ggml_v2.o ggml_v1.o expose.o gpttype_adapter_vulkan.o ggml-vulkan.o sdcpp_vulkan.o whispercpp_default.o tts_default.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o $(OBJS_FULL) $(OBJS)
792798
$(VULKAN_BUILD)
793799
ifdef NOAVX2_BUILD
794-
koboldcpp_vulkan_noavx2: ggml_v4_vulkan_noavx2.o ggml-cpu_v4_noavx2.o ggml_v3_noavx2.o ggml_v2_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_vulkan_noavx2.o ggml-vulkan.o sdcpp_vulkan.o whispercpp_default.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o $(OBJS_SIMPLE) $(OBJS)
800+
koboldcpp_vulkan_noavx2: ggml_v4_vulkan_noavx2.o ggml-cpu_v4_noavx2.o ggml_v3_noavx2.o ggml_v2_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_vulkan_noavx2.o ggml-vulkan.o sdcpp_vulkan.o whispercpp_default.o tts_default.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o $(OBJS_SIMPLE) $(OBJS)
795801
$(VULKAN_BUILD)
796802
else
797803
koboldcpp_vulkan_noavx2:

common/arg.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2216,6 +2216,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
22162216
params.vocoder.model = value;
22172217
}
22182218
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
2219+
add_opt(common_arg(
2220+
{"--tts-use-guide-tokens"},
2221+
"Use guide tokens to improve TTS word recall",
2222+
[](common_params & params) {
2223+
params.vocoder.use_guide_tokens = true;
2224+
}
2225+
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
22192226

22202227
// model-specific
22212228
add_opt(common_arg(

common/common.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ struct common_params_vocoder {
174174

175175
std::string model = ""; // model path // NOLINT
176176
std::string model_url = ""; // model url to download // NOLINT
177+
178+
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
177179
};
178180

179181
struct common_params {

examples/tts/tts.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,29 @@ static void prompt_init(llama_tokens & prompt, const llama_model * model) {
425425
prompt_add(prompt, model, "<|im_start|>\n", true, true);
426426
}
427427

428+
static std::vector<llama_token> prepare_guide_tokens(const llama_model * model, const std::string& str)
429+
{
430+
const std::string& delimiter = "<|text_sep|>";
431+
432+
std::vector<llama_token> result;
433+
size_t start = 0;
434+
size_t end = str.find(delimiter);
435+
436+
while (end != std::string::npos) {
437+
std::string current_word = str.substr(start, end - start);
438+
auto tmp = common_tokenize(model, current_word, false, true);
439+
result.push_back(tmp[0]);
440+
start = end + delimiter.length();
441+
end = str.find(delimiter, start);
442+
}
443+
444+
// Add the last part
445+
std::string current_word = str.substr(start);
446+
auto tmp = common_tokenize(model, current_word, false, true);
447+
result.push_back(tmp[0]);
448+
return result;
449+
}
450+
428451
int main(int argc, char ** argv) {
429452
common_params params;
430453

@@ -492,6 +515,7 @@ int main(int argc, char ** argv) {
492515
const auto t_main_start = ggml_time_us();
493516

494517
std::vector<llama_token> codes;
518+
std::vector<llama_token> guide_tokens;
495519

496520
// process prompt and generate voice codes
497521
{
@@ -506,6 +530,10 @@ int main(int argc, char ** argv) {
506530
// convert the input text into the necessary format expected by OuteTTS
507531
{
508532
std::string prompt_clean = process_text(params.prompt);
533+
if(params.vocoder.use_guide_tokens)
534+
{
535+
guide_tokens = prepare_guide_tokens(model_ttc,prompt_clean);
536+
}
509537

510538
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
511539

@@ -715,6 +743,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
715743
int n_past = batch.n_tokens;
716744
int n_decode = 0;
717745

746+
bool next_token_uses_guide_token = true;
747+
718748
while (n_decode <= n_predict) {
719749
// prepare the next batch
720750
common_batch_clear(batch);
@@ -726,7 +756,18 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
726756
continue;
727757
}
728758

729-
const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
759+
llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
760+
761+
//guide tokens help prevent hallucinations by forcing the TTS to use the correct word
762+
if(!guide_tokens.empty() && next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
763+
{
764+
llama_token guide_token = guide_tokens[0];
765+
guide_tokens.erase(guide_tokens.begin());
766+
new_token_id = guide_token; //ensure correct word fragment is used
767+
}
768+
769+
//this is the token id that always precedes a new word
770+
next_token_uses_guide_token = (new_token_id == 198);
730771

731772
common_sampler_accept(smpl[i], new_token_id, true);
732773

expose.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,15 @@ extern "C"
238238
return whispertype_generate(inputs);
239239
}
240240

241+
// Exposed entry point for loading a TTS model: forwards `inputs` unchanged to
// ttstype_load_model() and returns that call's bool result.
// NOTE(review): the hunk header suggests this sits inside the extern "C" block - confirm.
bool tts_load_model(const tts_load_model_inputs inputs)
242+
{
243+
return ttstype_load_model(inputs);
244+
}
245+
// Exposed entry point for TTS generation: forwards `inputs` unchanged to
// ttstype_generate() and returns the tts_generation_outputs it produces.
// NOTE(review): the hunk header suggests this sits inside the extern "C" block - confirm.
tts_generation_outputs tts_generate(const tts_generation_inputs inputs)
246+
{
247+
return ttstype_generate(inputs);
248+
}
249+
241250
const char * new_token(int idx) {
242251
if (generated_tokens.size() <= idx || idx < 0) return nullptr;
243252

0 commit comments

Comments
 (0)