diff --git a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
index 99e49e782d..94973a15de 100644
--- a/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
+++ b/android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt
@@ -317,7 +317,15 @@ class MainActivity : AppCompatActivity() {
             ruleFars = ruleFars ?: "",
         )!!
 
-        tts = OfflineTts(assetManager = assets, config = config)
+        val cacheConfig = getOfflineTtsCacheMechanismConfig(
+            dataDir = dataDir ?: "",
+            cacheSize = 20 * 1024 * 1024, // Default is 20 MB
+        )!!
+
+        // Generated audio is cached so that repeated texts need not be synthesized again
+        val cache = OfflineTtsCacheMechanism(cacheConfig)
+
+        tts = OfflineTts(assetManager = assets, config = config, cache = cache)
     }
 
 
diff --git a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/MainActivity.kt b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/MainActivity.kt
index c96f9f0efc..f1f12079dd 100644
--- a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/MainActivity.kt
+++ b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/MainActivity.kt
@@ -34,6 +34,7 @@ import androidx.compose.material3.Slider
 import androidx.compose.material3.Surface
 import androidx.compose.material3.Text
 import androidx.compose.material3.TopAppBar
+import androidx.compose.runtime.LaunchedEffect
 import androidx.compose.runtime.getValue
 import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.remember
@@ -48,6 +49,8 @@ import kotlinx.coroutines.channels.Channel
 import kotlinx.coroutines.launch
 import kotlinx.coroutines.withContext
 import java.io.File
+import kotlin.math.roundToInt
+import kotlinx.coroutines.delay
 import kotlin.time.TimeSource
 
 const val TAG = "sherpa-onnx-tts-engine"
@@ -78,6 +81,9 @@ class MainActivity : ComponentActivity() {
         Log.i(TAG, "Finish initializing AudioTrack")
 
val preferenceHelper = PreferenceHelper(this) + + TtsEngine.cacheSize = preferenceHelper.getTtsMechanismCacheSize() + setContent { SherpaOnnxTtsEngineTheme { // A surface container using the 'background' color from the theme @@ -90,6 +96,17 @@ class MainActivity : ComponentActivity() { }) { Box(modifier = Modifier.padding(it)) { Column(modifier = Modifier.padding(16.dp)) { + // Track used cache size in a mutable state + var usedCacheSize by remember { mutableStateOf(0) } + + // LaunchedEffect to periodically update the used cache size + LaunchedEffect(Unit) { + while (true) { + usedCacheSize = TtsEngine.tts?.getTotalUsedCacheSize() ?: 0 + delay(5000) // Update every 5 seconds + } + } + Column { Text("Speed " + String.format("%.1f", TtsEngine.speed)) Slider( @@ -97,10 +114,26 @@ class MainActivity : ComponentActivity() { onValueChange = { TtsEngine.speed = it preferenceHelper.setSpeed(it) + TtsEngine.tts?.clearCache() // Call the clearCache method + usedCacheSize = 0 // Reset used cache size }, valueRange = 0.2F..3.0F, modifier = Modifier.fillMaxWidth() ) + + Text("Cache Size: ${TtsEngine.cacheSize / (1024 * 1024)}MB (${usedCacheSize / (1024 * 1024)}MB used)") + Slider( + value = TtsEngine.cacheSizeState.value.toFloat(), + onValueChange = { newValue -> + // Round the value to the nearest multiple of 5MB + val roundedValue = (newValue / (5 * 1024 * 1024)).roundToInt() * (5 * 1024 * 1024) + TtsEngine.cacheSize = roundedValue + preferenceHelper.setCacheSize(roundedValue) + TtsEngine.tts?.setCacheSize(roundedValue) + }, + valueRange = 0f..209715200f, // 200MB + modifier = Modifier.fillMaxWidth() + ) } val testTextContent = getSampleText(TtsEngine.lang ?: "") @@ -213,8 +246,8 @@ class MainActivity : ComponentActivity() { val RTF = String.format( "Number of threads: %d\nElapsed: %.3f s\nAudio duration: %.3f s\nRTF: %.3f/%.3f = %.3f", TtsEngine.tts!!.config.model.numThreads, - audioDuration, elapsed, + audioDuration, elapsed, audioDuration, elapsed / audioDuration @@ 
-277,6 +310,13 @@ class MainActivity : ComponentActivity() { } } + override fun onResume() { + super.onResume() + // Update used cache size when the app is resumed + val usedCacheSize = (TtsEngine.tts?.getTotalUsedCacheSize() ?: 0) + Log.i(TAG, "App resumed. Used cache size: ${usedCacheSize}B") + } + override fun onDestroy() { stopMediaPlayer() super.onDestroy() diff --git a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/PreferencesHelper.kt b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/PreferencesHelper.kt index 57a6e324ca..b856914be8 100644 --- a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/PreferencesHelper.kt +++ b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/PreferencesHelper.kt @@ -6,6 +6,7 @@ class PreferenceHelper(context: Context) { private val PREFS_NAME = "com.k2fsa.sherpa.onnx.tts.engine" private val SPEED_KEY = "speed" private val SID_KEY = "speaker_id" + private val CACHE_SIZE_KEY = "cache_size" private val sharedPreferences: SharedPreferences = context.getSharedPreferences(PREFS_NAME, Context.MODE_PRIVATE) @@ -29,4 +30,14 @@ class PreferenceHelper(context: Context) { fun getSid(): Int { return sharedPreferences.getInt(SID_KEY, 0) } -} \ No newline at end of file + + fun setCacheSize(value: Int) { + val editor = sharedPreferences.edit() + editor.putInt(CACHE_SIZE_KEY, value) + editor.apply() + } + + fun getTtsMechanismCacheSize(): Int { + return sharedPreferences.getInt(CACHE_SIZE_KEY, 20*(1024*1024)) // Default cache size is 20MB + } +} diff --git a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsEngine.kt b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsEngine.kt index 2ae628c271..6c5419703b 100644 --- a/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsEngine.kt +++ 
b/android/SherpaOnnxTtsEngine/app/src/main/java/com/k2fsa/sherpa/onnx/tts/engine/TtsEngine.kt @@ -9,6 +9,7 @@ import androidx.compose.runtime.mutableFloatStateOf import androidx.compose.runtime.mutableIntStateOf import com.k2fsa.sherpa.onnx.OfflineTts import com.k2fsa.sherpa.onnx.getOfflineTtsConfig +import com.k2fsa.sherpa.onnx.getOfflineTtsCacheMechanismConfig import java.io.File import java.io.FileOutputStream import java.io.IOException @@ -25,6 +26,7 @@ object TtsEngine { val speedState: MutableState = mutableFloatStateOf(1.0F) + val cacheSizeState: MutableState = mutableIntStateOf(0) val speakerIdState: MutableState = mutableIntStateOf(0) var speed: Float @@ -33,6 +35,12 @@ object TtsEngine { speedState.value = value } + var cacheSize: Int + get() = cacheSizeState.value + set(value) { + cacheSizeState.value = value + } + var speakerId: Int get() = speakerIdState.value set(value) { @@ -166,6 +174,8 @@ object TtsEngine { // // This model supports many languages, e.g., English, Chinese, etc. // We set lang to eng here. 
+
+
     }
 
     fun createTts(context: Context) {
@@ -204,10 +214,19 @@ object TtsEngine {
             ruleFars = ruleFars ?: ""
         )
 
+        cacheSize = PreferenceHelper(context).getTtsMechanismCacheSize()
+        val cacheConfig = getOfflineTtsCacheMechanismConfig(
+            dataDir = dataDir ?: "",
+            cacheSize = cacheSize,
+        )
+
         speed = PreferenceHelper(context).getSpeed()
         speakerId = PreferenceHelper(context).getSid()
 
-        tts = OfflineTts(assetManager = assets, config = config)
+        // Generated audio is cached so that repeated texts need not be synthesized again
+        val cache = OfflineTtsCacheMechanism(cacheConfig)
+
+        tts = OfflineTts(assetManager = assets, config = config, cache = cache)
     }
 
 
diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt
index 5ee7a50507..8abac13f63 100644
--- a/sherpa-onnx/csrc/CMakeLists.txt
+++ b/sherpa-onnx/csrc/CMakeLists.txt
@@ -156,6 +156,8 @@ if(SHERPA_ONNX_ENABLE_TTS)
     kokoro-multi-lang-lexicon.cc
     lexicon.cc
     melo-tts-lexicon.cc
+    offline-tts-cache-mechanism-config.cc
+    offline-tts-cache-mechanism.cc
     offline-tts-character-frontend.cc
     offline-tts-frontend.cc
     offline-tts-impl.cc
diff --git a/sherpa-onnx/csrc/offline-tts-cache-mechanism-config.cc b/sherpa-onnx/csrc/offline-tts-cache-mechanism-config.cc
new file mode 100644
index 0000000000..bd06794fe8
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-tts-cache-mechanism-config.cc
@@ -0,0 +1,35 @@
+// sherpa-onnx/csrc/offline-tts-cache-mechanism-config.cc
+//
+// Copyright (c) 2025 @mah92 From Iranian people to the community with love
+
+#include "sherpa-onnx/csrc/offline-tts-cache-mechanism-config.h"
+
+#include <sstream>
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void OfflineTtsCacheMechanismConfig::Register(ParseOptions *po) {
+  po->Register("tts-cache-dir", &cache_dir,
+               "Path to the directory where cached wav files are stored.");
+  po->Register("tts-cache-size", &cache_size,
+               "Cache size for wav files in bytes. 
After the cache size is filled, wav files are kept based on usage statistics.");
+}
+
+bool OfflineTtsCacheMechanismConfig::Validate() const {
+  return true;
+}
+
+std::string OfflineTtsCacheMechanismConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflineTtsCacheMechanismConfig(";
+  os << "cache_dir=\"" << cache_dir << "\", ";
+  os << "cache_size=" << cache_size << ")";
+
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-tts-cache-mechanism-config.h b/sherpa-onnx/csrc/offline-tts-cache-mechanism-config.h
new file mode 100644
index 0000000000..2f5d2baba4
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-tts-cache-mechanism-config.h
@@ -0,0 +1,39 @@
+// sherpa-onnx/csrc/offline-tts-cache-mechanism-config.h
+//
+// Copyright (c) 2025 @mah92 From Iranian people to the community with love
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_CACHE_MECHANISM_CONFIG_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_TTS_CACHE_MECHANISM_CONFIG_H_
+
+#include <cstdint>
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OfflineTtsCacheMechanismConfig {
+  // Directory in which the cached wav files are stored
+  std::string cache_dir;
+
+  // Maximum total size of the cached wav files in bytes; a negative value
+  // means unlimited. Initialized so that a default-constructed config never
+  // holds an indeterminate value. 20 MB matches the default of the Android apps.
+  int32_t cache_size = 20 * 1024 * 1024;
+
+  OfflineTtsCacheMechanismConfig() = default;
+
+  OfflineTtsCacheMechanismConfig(const std::string &cache_dir,
+                                 int32_t cache_size)
+      : cache_dir(cache_dir),
+        cache_size(cache_size) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_TTS_CACHE_MECHANISM_CONFIG_H_
diff --git a/sherpa-onnx/csrc/offline-tts-cache-mechanism.cc b/sherpa-onnx/csrc/offline-tts-cache-mechanism.cc
new file mode 100644
index 0000000000..a335158253
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-tts-cache-mechanism.cc
@@ -0,0 +1,394 @@
+// sherpa-onnx/csrc/offline-tts-cache-mechanism.cc
+//
+// Copyright (c) 2025 @mah92 From Iranian people to the community with love
+
+#include "sherpa-onnx/csrc/offline-tts-cache-mechanism.h"
+
+#include <algorithm>
+#include <chrono>  // NOLINT
+#include <cstddef>  // for std::size_t
+#include <cstdint>
+#include <filesystem>  // NOLINT
+#include <fstream>
+#include <limits>
+#include <mutex>  // NOLINT
+#include <string>
+#include <system_error>  // NOLINT
+#include <vector>
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/wave-reader.h"
+#include "sherpa-onnx/csrc/wave-writer.h"
+
+namespace sherpa_onnx {
+
+namespace {
+
+namespace fs = std::filesystem;
+
+// A negative size means "unlimited": map it to the largest representable
+// limit so that the eviction check never fires.
+int32_t NormalizeCacheSize(int32_t cache_size) {
+  return cache_size < 0 ? std::numeric_limits<int32_t>::max() : cache_size;
+}
+
+// Size of `path` in bytes, clamped to int32_t; 0 if it cannot be determined.
+int32_t FileSizeOrZero(const fs::path &path) {
+  std::error_code ec;
+  const std::uintmax_t size = fs::file_size(path, ec);
+  if (ec) return 0;
+
+  const std::uintmax_t max_size = std::numeric_limits<int32_t>::max();
+  return static_cast<int32_t>(std::min(size, max_size));
+}
+
+// a + b for non-negative byte counts, saturating instead of overflowing.
+int32_t SaturatingAdd(int32_t a, int32_t b) {
+  const int64_t sum = static_cast<int64_t>(a) + b;
+  const int64_t max_size = std::numeric_limits<int32_t>::max();
+  return static_cast<int32_t>(std::min(sum, max_size));
+}
+
+// Parse a file stem such as "1234567" into a hash. Returns false unless the
+// stem is a plain decimal number that fits into std::size_t.
+bool ParseHash(const std::string &stem, std::size_t *hash) {
+  if (stem.empty()) return false;
+
+  const std::size_t max_value = std::numeric_limits<std::size_t>::max();
+  std::size_t value = 0;
+  for (char c : stem) {
+    if (c < '0' || c > '9') return false;
+
+    const std::size_t digit = static_cast<std::size_t>(c - '0');
+    if (value > (max_value - digit) / 10) return false;  // would overflow
+
+    value = value * 10 + digit;
+  }
+
+  *hash = value;
+  return true;
+}
+
+// Paths of all *.wav files directly inside `dir`. Errors are not fatal: an
+// unreadable directory simply yields an empty (or shorter) list.
+std::vector<fs::path> ListWavFiles(const std::string &dir) {
+  std::vector<fs::path> ans;
+  std::error_code ec;
+  for (fs::directory_iterator it(dir, ec), end; !ec && it != end;
+       it.increment(ec)) {
+    if (it->path().extension() == ".wav") ans.push_back(it->path());
+  }
+  return ans;
+}
+
+}  // namespace
+
+OfflineTtsCacheMechanism::OfflineTtsCacheMechanism(
+    const OfflineTtsCacheMechanismConfig &config)
+    : cache_dir_(config.cache_dir),
+      cache_size_bytes_(NormalizeCacheSize(config.cache_size)),
+      used_cache_size_bytes_(0),
+      cache_mechanism_inited_(false) {
+  // Create the cache directory (and missing parents). The error_code
+  // overloads are used so that a bad path disables the cache instead of
+  // throwing out of the constructor.
+  std::error_code ec;
+  fs::create_directories(cache_dir_, ec);
+  if (ec || !fs::is_directory(cache_dir_, ec)) {
+    SHERPA_ONNX_LOGE("Unable to create cache directory: %s",
+                     cache_dir_.c_str());
+    SHERPA_ONNX_LOGE("Cache mechanism disabled!");
+    return;
+  }
+
+  // Load the repeat counts, then reconcile them with the files on disk
+  LoadRepeatCounts();
+  UpdateCacheVector();
+
+  last_save_time_ = std::chrono::steady_clock::now();
+
+  // Indicate that initialization has been successful
+  cache_mechanism_inited_ = true;
+}
+
+OfflineTtsCacheMechanism::~OfflineTtsCacheMechanism() {
+  std::lock_guard<std::recursive_mutex> lock(mutex_);
+
+  if (!cache_mechanism_inited_) return;
+
+  // Save the repeat counts on destruction
+  SaveRepeatCounts();
+}
+
+void OfflineTtsCacheMechanism::AddWavFile(
+    const std::size_t &text_hash,
+    const std::vector<float> &samples,
+    const int32_t sample_rate) {
+  std::lock_guard<std::recursive_mutex> lock(mutex_);
+
+  if (!cache_mechanism_inited_) return;
+
+  // Nothing to store for empty audio; a zero budget means caching is off.
+  if (samples.empty() || cache_size_bytes_ == 0) return;
+
+  std::string file_path = cache_dir_ + "/" + std::to_string(text_hash) + ".wav";
+
+  std::error_code ec;
+  if (fs::exists(file_path, ec)) return;  // already cached
+
+  bool success = WriteWave(file_path, sample_rate, samples.data(),
+                           samples.size());
+  if (!success) {
+    SHERPA_ONNX_LOGE("Failed to write wav file: %s", file_path.c_str());
+    fs::remove(file_path, ec);  // do not leave a truncated file behind
+    return;
+  }
+
+  // Register the entry so that the eviction logic can find it even when the
+  // caller did not call GetWavFile() first.
+  repeat_counts_.emplace(text_hash, 1);
+  cache_vector_.push_back(text_hash);
+  used_cache_size_bytes_ =
+      SaturatingAdd(used_cache_size_bytes_, FileSizeOrZero(file_path));
+
+  // Evict the least used entries if the cache grew past its limit
+  EnsureCacheLimit();
+}
+
+std::vector<float> OfflineTtsCacheMechanism::GetWavFile(
+    const std::size_t &text_hash,
+    int32_t *sample_rate) {
+  std::lock_guard<std::recursive_mutex> lock(mutex_);
+
+  std::vector<float> samples;
+
+  if (!cache_mechanism_inited_) return samples;
+
+  std::string file_path = cache_dir_ + "/" + std::to_string(text_hash) + ".wav";
+
+  std::error_code ec;
+  if (fs::exists(file_path, ec)) {
+    bool is_ok = false;
+    samples = ReadWave(file_path, sample_rate, &is_ok);
+
+    if (!is_ok) {
+      SHERPA_ONNX_LOGE("Failed to read cached file: %s", file_path.c_str());
+      samples.clear();  // never hand a partially decoded file to the caller
+    }
+  }
+
+  // Count every request, hit or miss; operator[] starts new entries at 0
+  ++repeat_counts_[text_hash];
+
+  // Save the repeat counts at most once per minute
+  auto now = std::chrono::steady_clock::now();
+  if (now - last_save_time_ >= std::chrono::minutes(1)) {
+    SaveRepeatCounts();
+    last_save_time_ = now;
+  }
+
+  return samples;
+}
+
+int32_t OfflineTtsCacheMechanism::GetCacheSize() const {
+  std::lock_guard<std::recursive_mutex> lock(mutex_);
+
+  if (!cache_mechanism_inited_) return 0;
+
+  return cache_size_bytes_;
+}
+
+void OfflineTtsCacheMechanism::SetCacheSize(int32_t cache_size) {
+  std::lock_guard<std::recursive_mutex> lock(mutex_);
+
+  if (!cache_mechanism_inited_) return;
+
+  // Same convention as the constructor: a negative size means "unlimited"
+  cache_size_bytes_ = NormalizeCacheSize(cache_size);
+
+  if (cache_size_bytes_ == 0) {
+    ClearCache();
+  } else {
+    EnsureCacheLimit();
+  }
+}
+
+void OfflineTtsCacheMechanism::ClearCache() {
+  std::lock_guard<std::recursive_mutex> lock(mutex_);
+
+  if (!cache_mechanism_inited_) return;
+
+  // Remove all WAV files in the cache directory. The list is collected
+  // first so that no file is deleted while the directory is being iterated.
+  for (const auto &path : ListWavFiles(cache_dir_)) {
+    std::error_code ec;
+    fs::remove(path, ec);
+  }
+
+  used_cache_size_bytes_ = 0;
+  repeat_counts_.clear();
+  cache_vector_.clear();
+
+  // Also drop the counters stored in the repeat count file
+  SaveRepeatCounts();
+}
+
+int32_t OfflineTtsCacheMechanism::GetTotalUsedCacheSize() const {
+  std::lock_guard<std::recursive_mutex> lock(mutex_);
+
+  if (!cache_mechanism_inited_) return 0;
+
+  return used_cache_size_bytes_;
+}
+
+// Private functions ///////////////////////////////////////////////////
+
+void OfflineTtsCacheMechanism::LoadRepeatCounts() {
+  std::string repeat_count_file = cache_dir_ + "/repeat_counts.bin";
+
+  std::error_code ec;
+  if (!fs::exists(repeat_count_file, ec)) {
+    return;  // Nothing has been saved yet
+  }
+
+  std::ifstream ifs(repeat_count_file, std::ios::binary);
+  if (!ifs.is_open()) {
+    SHERPA_ONNX_LOGE("Failed to open repeat count file: %s",
+                     repeat_count_file.c_str());
+    return;
+  }
+
+  std::size_t num_entries = 0;
+  ifs.read(reinterpret_cast<char *>(&num_entries), sizeof(num_entries));
+
+  // Every read is checked: a truncated or corrupt file must neither insert
+  // garbage nor keep the loop running for a bogus number of entries.
+  for (std::size_t i = 0; ifs && i < num_entries; ++i) {
+    std::size_t text_hash = 0;
+    std::size_t count = 0;
+    ifs.read(reinterpret_cast<char *>(&text_hash), sizeof(text_hash));
+    ifs.read(reinterpret_cast<char *>(&count), sizeof(count));
+    if (!ifs) {
+      SHERPA_ONNX_LOGE("Repeat count file is truncated: %s",
+                       repeat_count_file.c_str());
+      break;
+    }
+    repeat_counts_[text_hash] = count;
+  }
+}
+
+void OfflineTtsCacheMechanism::SaveRepeatCounts() {
+  std::string repeat_count_file = cache_dir_ + "/repeat_counts.bin";
+
+  std::ofstream ofs(repeat_count_file, std::ios::binary);
+  if (!ofs.is_open()) {
+    SHERPA_ONNX_LOGE("Failed to open repeat count file for writing: %s",
+                     repeat_count_file.c_str());
+    return;
+  }
+
+  // Layout: number of entries, then (hash, count) pairs, all as std::size_t
+  std::size_t num_entries = repeat_counts_.size();
+  ofs.write(reinterpret_cast<const char *>(&num_entries), sizeof(num_entries));
+
+  for (const auto &entry : repeat_counts_) {
+    ofs.write(reinterpret_cast<const char *>(&entry.first),
+              sizeof(entry.first));
+    ofs.write(reinterpret_cast<const char *>(&entry.second),
+              sizeof(entry.second));
+  }
+
+  if (!ofs) {
+    SHERPA_ONNX_LOGE("Failed to write repeat count file: %s",
+                     repeat_count_file.c_str());
+  }
+}
+
+void OfflineTtsCacheMechanism::RemoveWavFile(const std::size_t &text_hash) {
+  std::string file_path = cache_dir_ + "/" + std::to_string(text_hash) + ".wav";
+
+  std::error_code ec;
+  if (fs::exists(file_path, ec)) {
+    // Only give the bytes back if the file is really gone
+    const int32_t file_size = FileSizeOrZero(file_path);
+    if (fs::remove(file_path, ec)) {
+      used_cache_size_bytes_ -= std::min(file_size, used_cache_size_bytes_);
+    }
+  }
+
+  // Remove the entry from the repeat counts and the cache vector
+  repeat_counts_.erase(text_hash);
+  cache_vector_.erase(
+      std::remove(cache_vector_.begin(), cache_vector_.end(), text_hash),
+      cache_vector_.end());
+}
+
+void OfflineTtsCacheMechanism::UpdateCacheVector() {
+  used_cache_size_bytes_ = 0;  // Reset total cache size before recalculating
+  cache_vector_.clear();
+
+  for (const auto &path : ListWavFiles(cache_dir_)) {
+    std::size_t text_hash = 0;
+    if (!ParseHash(path.stem().string(), &text_hash)) {
+      continue;  // not a file written by this cache; leave it alone
+    }
+
+    if (repeat_counts_.find(text_hash) == repeat_counts_.end()) {
+      // Remove the file if it's not in the repeat count file (orphaned file)
+      std::error_code ec;
+      fs::remove(path, ec);
+      continue;
+    }
+
+    used_cache_size_bytes_ =
+        SaturatingAdd(used_cache_size_bytes_, FileSizeOrZero(path));
+    cache_vector_.push_back(text_hash);
+  }
+}
+
+void OfflineTtsCacheMechanism::EnsureCacheLimit() {
+  std::lock_guard<std::recursive_mutex> lock(mutex_);
+
+  if (used_cache_size_bytes_ <= cache_size_bytes_) return;
+
+  // The cleanup runs synchronously under the lock. A detached thread that
+  // captures `this` could outlive the object and touch freed memory.
+  //
+  // Evict down to 95% of the limit so that the next few additions do not
+  // trigger another cleanup right away.
+  const int32_t target_cache_size =
+      static_cast<int32_t>(cache_size_bytes_ * 0.95);
+  while (used_cache_size_bytes_ > target_cache_size &&
+         !repeat_counts_.empty()) {
+    // Cache is full, remove the least repeated file
+    RemoveWavFile(GetLeastRepeatedFile());
+  }
+
+  if (repeat_counts_.empty()) {
+    // Nothing is tracked any more; recount what is actually left on disk so
+    // that the used size cannot stay stuck above the limit.
+    UpdateCacheVector();
+  }
+}
+
+std::size_t OfflineTtsCacheMechanism::GetLeastRepeatedFile() {
+  std::size_t least_repeated_file = 0;
+  std::size_t min_count = std::numeric_limits<std::size_t>::max();
+
+  for (const auto &entry : repeat_counts_) {
+    if (entry.second <= 1) {
+      return entry.first;  // cannot get any lower, stop searching
+    }
+
+    // "<=" so that an entry is selected even if every count is at the maximum
+    if (entry.second <= min_count) {
+      min_count = entry.second;
+      least_repeated_file = entry.first;
+    }
+  }
+
+  return least_repeated_file;
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-tts-cache-mechanism.h b/sherpa-onnx/csrc/offline-tts-cache-mechanism.h
new file mode 100644
index 0000000000..9945b1f3cd
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-tts-cache-mechanism.h
@@ -0,0 +1,99 @@
+// sherpa-onnx/csrc/offline-tts-cache-mechanism.h
+//
+// Copyright (c) 2025 @mah92 From Iranian people to the community with love
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_CACHE_MECHANISM_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_TTS_CACHE_MECHANISM_H_
+
+#include <chrono>  // NOLINT
+#include <cstddef>  // for std::size_t
+#include <cstdint>
+#include <mutex>  // NOLINT
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "sherpa-onnx/csrc/offline-tts-cache-mechanism-config.h"
+
+namespace sherpa_onnx {
+
+// Caches generated audio as wav files inside a directory. Each entry is
+// identified by a hash supplied by the caller. When the files take more space
+// than the configured limit, the least requested entries are removed first.
+// All public methods lock an internal mutex.
+class OfflineTtsCacheMechanism {
+ public:
+  explicit OfflineTtsCacheMechanism(const OfflineTtsCacheMechanismConfig &config);
+  ~OfflineTtsCacheMechanism();
+
+  // Add a new wav file to the cache
+  void AddWavFile(
+    const std::size_t &text_hash,
+    const std::vector<float> &samples,
+    const int32_t sample_rate);
+
+  // Get the cached wav file if it exists
+  // Returns an empty vector if nothing usable is cached for text_hash
+  std::vector<float> GetWavFile(
+    const std::size_t &text_hash,
+    int32_t *sample_rate);
+
+  // Get the current cache size in bytes
+  int32_t GetCacheSize() const;
+
+  // Set the cache size in bytes
+  void SetCacheSize(int32_t cache_size);
+
+  // Remove all the wav files in the cache
+  void ClearCache();
+
+  // To get total used cache size(for wav files) in bytes
+  int32_t GetTotalUsedCacheSize() const;
+
+ private:
+  // Load the repeat count file
+  void LoadRepeatCounts();
+
+  // Save the repeat count file
+  void SaveRepeatCounts();
+
+  // Remove a wav file from the cache
+  void RemoveWavFile(const std::size_t &text_hash);
+
+  // Update the cache vector with the actual files in the cache folder
+  void UpdateCacheVector();
+
+  // Reduce used cache size if needed
+  void EnsureCacheLimit();
+
+  // Get the least repeated file in the cache
+  std::size_t GetLeastRepeatedFile();
+
+  // Data directory where the cache folder is located
+  std::string cache_dir_;
+
+  // Maximum number of bytes in the cache
+  int32_t cache_size_bytes_;
+
+  // Total used cache size for wav files in bytes
+  int32_t used_cache_size_bytes_;
+
+  // Map of text hash to repeat count
+  std::unordered_map<std::size_t, std::size_t> repeat_counts_;
+
+  // Vector of cached file names
+  std::vector<std::size_t> cache_vector_;
+
+  // Mutex for 
thread safety (recursive to avoid deadlocks) + mutable std::recursive_mutex mutex_; + + // Time of last save + std::chrono::steady_clock::time_point last_save_time_; + + // if cache mechanism is inited successfully + bool cache_mechanism_inited_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_CACHE_MECHANISM_H_ diff --git a/sherpa-onnx/csrc/offline-tts-impl.cc b/sherpa-onnx/csrc/offline-tts-impl.cc index 199b0f7926..31e73fb343 100644 --- a/sherpa-onnx/csrc/offline-tts-impl.cc +++ b/sherpa-onnx/csrc/offline-tts-impl.cc @@ -35,7 +35,8 @@ std::vector OfflineTtsImpl::AddBlank(const std::vector &x, } std::unique_ptr OfflineTtsImpl::Create( - const OfflineTtsConfig &config) { + const OfflineTtsConfig &config, OfflineTtsCacheMechanism* cache) { + cache_ = cache; if (!config.model.vits.model.empty()) { return std::make_unique(config); } else if (!config.model.matcha.acoustic_model.empty()) { @@ -47,7 +48,8 @@ std::unique_ptr OfflineTtsImpl::Create( template std::unique_ptr OfflineTtsImpl::Create( - Manager *mgr, const OfflineTtsConfig &config) { + Manager *mgr, const OfflineTtsConfig &config, OfflineTtsCacheMechanism* cache) { + cache_ = cache; if (!config.model.vits.model.empty()) { return std::make_unique(mgr, config); } else if (!config.model.matcha.acoustic_model.empty()) { @@ -59,12 +61,73 @@ std::unique_ptr OfflineTtsImpl::Create( #if __ANDROID_API__ >= 9 template std::unique_ptr OfflineTtsImpl::Create( - AAssetManager *mgr, const OfflineTtsConfig &config); + AAssetManager *mgr, const OfflineTtsConfig &config, OfflineTtsCacheMechanism* cache); #endif - + #if __OHOS__ template std::unique_ptr OfflineTtsImpl::Create( - NativeResourceManager *mgr, const OfflineTtsConfig &config); + NativeResourceManager *mgr, const OfflineTtsConfig &config, OfflineTtsCacheMechanism* cache); #endif
+
+// Definition of the static member declared in OfflineTtsImpl. Without it,
+// every use of cache_ is an undefined reference at link time.
+//
+// NOTE: the pointer is shared by all OfflineTtsImpl instances, so the most
+// recent call to Create() decides which cache every instance uses.
+OfflineTtsCacheMechanism *OfflineTtsImpl::cache_ = nullptr;
+
+GeneratedAudio OfflineTtsImpl::GenerateWitchCache(
+    const std::string &text, int64_t sid, float speed,
+    GeneratedAudioCallback callback) const {
+  // In phones, long texts come from messages, websites and books, which are
+  // usually not repeated. Repeated text comes from menus and settings, which
+  // are usually short. So only short texts go through the cache.
+  const bool text_is_long = text.length() > 50;
+  if (cache_ == nullptr || text_is_long) {
+    return Generate(text, sid, speed, callback);
+  }
+
+  // The audio depends on the speaker and the speed as well as on the text, so
+  // all three go into the cache key. A key made from the text alone would
+  // replay audio that was generated for another speaker or speed.
+  auto combine = [](std::size_t seed, std::size_t value) {
+    return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
+  };
+  std::size_t key = std::hash<std::string>()(text);
+  key = combine(key, std::hash<int64_t>()(sid));
+  key = combine(key, std::hash<float>()(speed));
+
+  int32_t sample_rate = 0;
+  std::vector<float> samples = cache_->GetWavFile(key, &sample_rate);
+  if (!samples.empty()) {
+    // The cached audio is delivered in one piece. The return value of the
+    // callback (0 means "stop") is ignored: nothing is left to generate.
+    if (callback) {
+      callback(samples.data(), samples.size(), 1.0f /* progress */);
+    }
+
+    return {samples, sample_rate};
+  }
+
+  // Remember whether the callback asked to stop: the audio may then be
+  // truncated and must not be cached as if it were the complete utterance.
+  bool stopped = false;
+  GeneratedAudioCallback wrapped;
+  if (callback) {
+    wrapped = [&callback, &stopped](const float *s, int32_t n, float p) {
+      int32_t result = callback(s, n, p);
+      if (result == 0) stopped = true;
+      return result;
+    };
+  }
+
+  auto audio = Generate(text, sid, speed, wrapped);
+  if (!stopped && !audio.samples.empty()) {
+    cache_->AddWavFile(key, audio.samples, audio.sample_rate);
+  }
+
+  return audio;
+}
+
 } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-impl.h b/sherpa-onnx/csrc/offline-tts-impl.h index 061acc747c..f7d4847340 100644 --- a/sherpa-onnx/csrc/offline-tts-impl.h +++ b/sherpa-onnx/csrc/offline-tts-impl.h @@ -17,16 +17,23 @@ class OfflineTtsImpl { public: virtual ~OfflineTtsImpl() = default; - static std::unique_ptr Create(const OfflineTtsConfig &config); + static std::unique_ptr Create( + const OfflineTtsConfig &config, + OfflineTtsCacheMechanism* cache = nullptr); template static std::unique_ptr Create(Manager *mgr, - const OfflineTtsConfig &config); + const OfflineTtsConfig &config, + OfflineTtsCacheMechanism* cache = nullptr); virtual GeneratedAudio Generate( const std::string 
&text, int64_t sid = 0, float speed = 1.0, GeneratedAudioCallback callback = nullptr) const = 0; + GeneratedAudio GenerateWitchCache( + const std::string &text, int64_t sid = 0, float speed = 1.0, + GeneratedAudioCallback callback = nullptr) const; + // Return the sample rate of the generated audio virtual int32_t SampleRate() const = 0; @@ -36,6 +43,8 @@ class OfflineTtsImpl { std::vector AddBlank(const std::vector &x, int32_t blank_id = 0) const; + private: + static OfflineTtsCacheMechanism *cache_; // not owned here }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts.cc b/sherpa-onnx/csrc/offline-tts.cc index d3fa14b51e..ea4e6ea646 100644 --- a/sherpa-onnx/csrc/offline-tts.cc +++ b/sherpa-onnx/csrc/offline-tts.cc @@ -19,6 +19,7 @@ #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-tts-cache-mechanism.h" #include "sherpa-onnx/csrc/offline-tts-impl.h" #include "sherpa-onnx/csrc/text-utils.h" @@ -161,39 +162,43 @@ std::string OfflineTtsConfig::ToString() const { return os.str(); } -OfflineTts::OfflineTts(const OfflineTtsConfig &config) - : impl_(OfflineTtsImpl::Create(config)) {} +OfflineTts::OfflineTts(const OfflineTtsConfig &config, + OfflineTtsCacheMechanism *cache) + : impl_(OfflineTtsImpl::Create(config, cache)) {} template -OfflineTts::OfflineTts(Manager *mgr, const OfflineTtsConfig &config) - : impl_(OfflineTtsImpl::Create(mgr, config)) {} +OfflineTts::OfflineTts(Manager *mgr, const OfflineTtsConfig &config, + OfflineTtsCacheMechanism *cache) + : impl_(OfflineTtsImpl::Create(mgr, config, cache)) {} OfflineTts::~OfflineTts() = default; GeneratedAudio OfflineTts::Generate( const std::string &text, int64_t sid /*=0*/, float speed /*= 1.0*/, GeneratedAudioCallback callback /*= nullptr*/) const { -#if !defined(_WIN32) - return impl_->Generate(text, sid, speed, std::move(callback)); -#else + + // Generate the audio if not cached + #if !defined(_WIN32) + return 
impl_->GenerateWitchCache(text, sid, speed, std::move(callback)); + #else if (IsUtf8(text)) { - return impl_->Generate(text, sid, speed, std::move(callback)); + return impl_->GenerateWitchCache(text, sid, speed, std::move(callback)); } else if (IsGB2312(text)) { auto utf8_text = Gb2312ToUtf8(text); static bool printed = false; if (!printed) { SHERPA_ONNX_LOGE( - "Detected GB2312 encoded string! Converting it to UTF8."); - printed = true; - } - return impl_->Generate(utf8_text, sid, speed, std::move(callback)); - } else { - SHERPA_ONNX_LOGE( + "Detected GB2312 encoded string! Converting it to UTF8."); + printed = true; + } + return impl_->GenerateWitchCache(utf8_text, sid, speed, std::move(callback)); + } else { + SHERPA_ONNX_LOGE( "Non UTF8 encoded string is received. You would not get expected " "results!"); - return impl_->Generate(text, sid, speed, std::move(callback)); - } -#endif + return impl_->GenerateWitchCache(text, sid, speed, std::move(callback)); + } + #endif } int32_t OfflineTts::SampleRate() const { return impl_->SampleRate(); } @@ -202,12 +207,14 @@ int32_t OfflineTts::NumSpeakers() const { return impl_->NumSpeakers(); } #if __ANDROID_API__ >= 9 template OfflineTts::OfflineTts(AAssetManager *mgr, - const OfflineTtsConfig &config); + const OfflineTtsConfig &config, + OfflineTtsCacheMechanism *cache = nullptr); #endif #if __OHOS__ template OfflineTts::OfflineTts(NativeResourceManager *mgr, - const OfflineTtsConfig &config); + const OfflineTtsConfig &config, + OfflineTtsCacheMechanism *cache = nullptr); #endif } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts.h b/sherpa-onnx/csrc/offline-tts.h index a505bd38cb..83f9b16be2 100644 --- a/sherpa-onnx/csrc/offline-tts.h +++ b/sherpa-onnx/csrc/offline-tts.h @@ -10,6 +10,7 @@ #include #include +#include "sherpa-onnx/csrc/offline-tts-cache-mechanism.h" #include "sherpa-onnx/csrc/offline-tts-model-config.h" #include "sherpa-onnx/csrc/parse-options.h" @@ -17,6 +18,7 @@ namespace sherpa_onnx { 
struct OfflineTtsConfig { OfflineTtsModelConfig model; + // If not empty, it contains a list of rule FST filenames. // Filenames are separated by a comma. // Example value: rule1.fst,rule2,fst,rule3.fst @@ -73,10 +75,12 @@ using GeneratedAudioCallback = std::function - OfflineTts(Manager *mgr, const OfflineTtsConfig &config); + OfflineTts(Manager *mgr, const OfflineTtsConfig &config, + OfflineTtsCacheMechanism *cache = nullptr); // @param text A string containing words separated by spaces // @param sid Speaker ID. Used only for multi-speaker models, e.g., models diff --git a/sherpa-onnx/jni/offline-tts.cc b/sherpa-onnx/jni/offline-tts.cc index 14d8cc4f36..1d6a29a8ac 100644 --- a/sherpa-onnx/jni/offline-tts.cc +++ b/sherpa-onnx/jni/offline-tts.cc @@ -193,11 +193,29 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { return ans; } +static OfflineTtsCacheMechanismConfig GetOfflineTtsCacheConfig(JNIEnv *env, jobject config) { + OfflineTtsCacheMechanismConfig ans; + + jclass cls = env->GetObjectClass(config); + jfieldID fid; + + fid = env->GetFieldID(cls, "cacheDir", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(config, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + ans.cache_dir = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "cacheSize", "I"); + ans.cache_size = env->GetIntField(config, fid); + + return ans; +} + } // namespace sherpa_onnx SHERPA_ONNX_EXTERN_C JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromAsset( - JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config, jobject _cache_config) { #if __ANDROID_API__ >= 9 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); if (!mgr) { @@ -205,6 +223,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromAsset( return 0; } #endif + auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config); 
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); @@ -212,22 +231,19 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromAsset( #if __ANDROID_API__ >= 9 mgr, #endif - config); + config, + cache); return (jlong)tts; } SHERPA_ONNX_EXTERN_C JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromFile( - JNIEnv *env, jobject /*obj*/, jobject _config) { + JNIEnv *env, jobject /*obj*/, jobject _config, jobject _cache_config) { auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config); SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); - if (!config.Validate()) { - SHERPA_ONNX_LOGE("Errors found in config!"); - } - - auto tts = new sherpa_onnx::OfflineTts(config); + auto tts = new sherpa_onnx::OfflineTts(config, cache); return (jlong)tts; } @@ -238,6 +254,24 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_delete( delete reinterpret_cast(ptr); } +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_setCacheSizeImpl( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr, jint cacheSize) { + if(cache_) { + cache_->SetCacheSize(static_cast(cacheSize)); + } +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getCacheSizeImpl( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + if(cache_) { + return cache_->GetCacheSize(); + } + + return 0; +} + SHERPA_ONNX_EXTERN_C JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getSampleRate( JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { @@ -250,6 +284,28 @@ JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getNumSpeakers( return reinterpret_cast(ptr)->NumSpeakers(); } +SHERPA_ONNX_EXTERN_C +JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_getTotalUsedCacheSizeImpl( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + if(cache_) { + return cache_->GetTotalUsedCacheSize(); + } + + return 0; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL 
+Java_com_k2fsa_sherpa_onnx_OfflineTts_clearCacheImpl( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + if(cache_) { + cache_->ClearCache(); + static int times = 0; + times++; + SHERPA_ONNX_LOGE("Cache cleared from JNI for %ith time\n", times); + } +} + SHERPA_ONNX_EXTERN_C JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_generateImpl(JNIEnv *env, jobject /*obj*/, diff --git a/sherpa-onnx/kotlin-api/Tts.kt b/sherpa-onnx/kotlin-api/Tts.kt index b4e0798408..db3381903d 100644 --- a/sherpa-onnx/kotlin-api/Tts.kt +++ b/sherpa-onnx/kotlin-api/Tts.kt @@ -52,6 +52,11 @@ data class OfflineTtsConfig( var silenceScale: Float = 0.2f, ) +data class OfflineTtsCacheMechanismConfig( + var cacheDir: String = "", + var cacheSize: Int = -1, // Unlimited +) + class GeneratedAudio( val samples: FloatArray, val sampleRate: Int, @@ -69,14 +74,15 @@ class GeneratedAudio( class OfflineTts( assetManager: AssetManager? = null, var config: OfflineTtsConfig, + var cache: OfflineTtsCacheMechanism, ) { private var ptr: Long init { ptr = if (assetManager != null) { - newFromAsset(assetManager, config) + newFromAsset(assetManager, config, cache) } else { - newFromFile(config) + newFromFile(config, cache) } } @@ -118,9 +124,9 @@ class OfflineTts( fun allocate(assetManager: AssetManager? 
= null) { if (ptr == 0L) { ptr = if (assetManager != null) { - newFromAsset(assetManager, config) + newFromAsset(assetManager, config, cacheConfig) } else { - newFromFile(config) + newFromFile(config, cacheConfig) } } } @@ -144,16 +150,34 @@ class OfflineTts( private external fun newFromAsset( assetManager: AssetManager, config: OfflineTtsConfig, + cacheConfig: OfflineTtsCacheMechanismConfig, ): Long private external fun newFromFile( config: OfflineTtsConfig, + cacheConfig: OfflineTtsCacheMechanismConfig, ): Long private external fun delete(ptr: Long) private external fun getSampleRate(ptr: Long): Int private external fun getNumSpeakers(ptr: Long): Int + fun getTtsMechanismCacheSize(): Int { + return (getCacheSizeImpl(ptr)).toInt() + } + private external fun getCacheSizeImpl(ptr: Long): Int + + fun setCacheSize(cacheSize: Int) { + setCacheSizeImpl(ptr, cacheSize) + } + private external fun setCacheSizeImpl(ptr: Long, cacheSize: Int) + + fun getTotalUsedCacheSize(): Int { + return (getTotalUsedCacheSizeImpl(ptr)).toInt() + } + + private external fun getTotalUsedCacheSizeImpl(ptr: Long): Int + // The returned array has two entries: // - the first entry is an 1-D float array containing audio samples. // Each sample is normalized to the range [-1, 1] @@ -173,6 +197,12 @@ class OfflineTts( callback: (samples: FloatArray) -> Int ): Array + fun clearCache() { + clearCacheImpl(ptr) + } + + private external fun clearCacheImpl(ptr: Long) + companion object { init { System.loadLibrary("sherpa-onnx-jni") @@ -281,3 +311,14 @@ fun getOfflineTtsConfig( ruleFars = ruleFars, ) } + +fun getOfflineTtsCacheMechanismConfig( + dataDir: String, + cacheSize: Int +): OfflineTtsCacheMechanismConfig { + + return OfflineTtsCacheMechanismConfig( + cacheDir = "$dataDir/../cache", + cacheSize = cacheSize, + ) +}