Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import 'package:runanywhere/core/protocols/component/component_configuration.dart';
import 'package:runanywhere/core/types/model_types.dart';
import 'package:runanywhere/foundation/error_types/sdk_error.dart';

/// Configuration for the LLM component.
///
/// Mirrors the validation contract used by the Swift and Kotlin SDKs so
/// invalid parameters fail in Dart before crossing the FFI boundary.
class LLMConfiguration implements ComponentConfiguration {
final String? modelId;
final InferenceFramework? preferredFramework;
final int contextLength;
final double temperature;
final int maxTokens;
final String? systemPrompt;
final bool streamingEnabled;

const LLMConfiguration({
this.modelId,
this.preferredFramework,
this.contextLength = 2048,
this.temperature = 0.7,
this.maxTokens = 100,
this.systemPrompt,
this.streamingEnabled = true,
});

@override
void validate() {
if (contextLength <= 0 || contextLength > 32768) {
throw SDKError.validationFailed(
'Context length must be between 1 and 32768',
);
}

if (!temperature.isFinite || temperature < 0 || temperature > 2.0) {
throw SDKError.validationFailed(
'Temperature must be between 0 and 2.0',
);
}

if (maxTokens <= 0 || maxTokens > contextLength) {
throw SDKError.validationFailed(
'Max tokens must be between 1 and context length',
);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import 'package:runanywhere/core/protocols/component/component_configuration.dart';
import 'package:runanywhere/core/types/model_types.dart';
import 'package:runanywhere/foundation/error_types/sdk_error.dart';

/// Configuration for the STT component.
///
/// Mirrors the validation contract used by the Swift and Kotlin SDKs so
/// invalid parameters fail in Dart before crossing the FFI boundary.
class STTConfiguration implements ComponentConfiguration {
final String? modelId;
final InferenceFramework? preferredFramework;
final String language;
final int sampleRate;
final bool enablePunctuation;
final bool enableDiarization;
final List<String> vocabularyList;
final int maxAlternatives;
final bool enableTimestamps;

const STTConfiguration({
this.modelId,
this.preferredFramework,
this.language = 'en-US',
this.sampleRate = 16000,
this.enablePunctuation = true,
this.enableDiarization = false,
this.vocabularyList = const <String>[],
this.maxAlternatives = 1,
this.enableTimestamps = true,
});

@override
void validate() {
if (sampleRate <= 0 || sampleRate > 48000) {
throw SDKError.validationFailed(
'Sample rate must be between 1 and 48000 Hz',
);
}

if (maxAlternatives <= 0 || maxAlternatives > 10) {
throw SDKError.validationFailed(
'Max alternatives must be between 1 and 10',
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,7 @@ import 'dart:async';
import 'dart:typed_data';

import 'package:flutter_tts/flutter_tts.dart';

/// Configuration for TTS synthesis
class TTSConfiguration {
final String voice;
final String language;
final double speakingRate;
final double pitch;
final double volume;
final String audioFormat;

const TTSConfiguration({
this.voice = 'system',
this.language = 'en-US',
this.speakingRate = 0.5,
this.pitch = 1.0,
this.volume = 1.0,
this.audioFormat = 'pcm',
});
}
import 'package:runanywhere/features/tts/tts_configuration.dart';

/// Input for TTS synthesis
class TTSSynthesisInput {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import 'package:runanywhere/core/protocols/component/component_configuration.dart';
import 'package:runanywhere/foundation/error_types/sdk_error.dart';

/// Configuration for TTS synthesis.
class TTSConfiguration implements ComponentConfiguration {
final String voice;
final String language;
final double speakingRate;
final double pitch;
final double volume;
final String audioFormat;

const TTSConfiguration({
this.voice = 'system',
this.language = 'en-US',
this.speakingRate = 0.5,
this.pitch = 1.0,
this.volume = 1.0,
this.audioFormat = 'pcm',
});

@override
void validate() {
if (!speakingRate.isFinite || speakingRate < 0.5 || speakingRate > 2.0) {
throw SDKError.validationFailed(
'Speaking rate must be between 0.5 and 2.0',
);
}

if (!pitch.isFinite || pitch < 0.5 || pitch > 2.0) {
throw SDKError.validationFailed(
'Pitch must be between 0.5 and 2.0',
);
}

if (!volume.isFinite || volume < 0.0 || volume > 1.0) {
throw SDKError.validationFailed(
'Volume must be between 0.0 and 1.0',
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ import 'dart:ffi';
import 'dart:isolate'; // Keep for non-streaming generation

import 'package:ffi/ffi.dart';

import 'package:runanywhere/features/llm/llm_configuration.dart';
import 'package:runanywhere/foundation/error_types/sdk_error.dart';
import 'package:runanywhere/foundation/logging/sdk_logger.dart';
import 'package:runanywhere/native/ffi_types.dart';
import 'package:runanywhere/native/platform_loader.dart';
Expand Down Expand Up @@ -48,6 +49,7 @@ class DartBridgeLLM {

RacHandle? _handle;
String? _loadedModelId;
int? _loadedContextLength;
final _logger = SDKLogger('DartBridge.LLM');

/// Active stream subscription for cancellation
Expand Down Expand Up @@ -153,6 +155,7 @@ class DartBridgeLLM {
String modelPath,
String modelId,
String modelName,
int? contextLength,
) async {
final handle = getHandle();

Expand Down Expand Up @@ -181,6 +184,7 @@ class DartBridgeLLM {
}

_loadedModelId = modelId;
_loadedContextLength = contextLength;
_logger.info('LLM model loaded: $modelId');
} finally {
calloc.free(pathPtr);
Expand All @@ -200,6 +204,7 @@ class DartBridgeLLM {

cleanupFn(_handle!);
_loadedModelId = null;
_loadedContextLength = null;
_logger.info('LLM model unloaded');
} catch (e) {
_logger.error('Failed to unload LLM model: $e');
Expand Down Expand Up @@ -247,6 +252,13 @@ class DartBridgeLLM {
throw StateError('No LLM model loaded. Call loadModel() first.');
}

_validateGenerationParameters(
contextLength: _requireLoadedContextLength(),
maxTokens: maxTokens,
temperature: temperature,
systemPrompt: systemPrompt,
);

// Run FFI call in a separate isolate to avoid heap corruption
// from C++ background threads (Metal GPU operations)
final handleAddress = handle.address;
Expand Down Expand Up @@ -290,6 +302,14 @@ class DartBridgeLLM {
throw StateError('No LLM model loaded. Call loadModel() first.');
}

_validateGenerationParameters(
contextLength: _requireLoadedContextLength(),
maxTokens: maxTokens,
temperature: temperature,
systemPrompt: systemPrompt,
streamingEnabled: true,
);

// Create stream controller for emitting tokens to the caller
final controller = StreamController<String>();

Expand Down Expand Up @@ -367,6 +387,33 @@ class DartBridgeLLM {
}
}

int _requireLoadedContextLength() {
final contextLength = _loadedContextLength;
if (contextLength != null && contextLength > 0) {
return contextLength;
}

throw SDKError.validationFailed(
'Loaded model is missing context length metadata for maxTokens validation',
);
}

void _validateGenerationParameters({
required int contextLength,
required int maxTokens,
required double temperature,
String? systemPrompt,
bool streamingEnabled = false,
}) {
LLMConfiguration(
contextLength: contextLength,
maxTokens: maxTokens,
temperature: temperature,
systemPrompt: systemPrompt,
streamingEnabled: streamingEnabled,
).validate();
}

// MARK: - Cleanup

/// Destroy the component and release resources.
Expand All @@ -380,6 +427,7 @@ class DartBridgeLLM {
destroyFn(_handle!);
_handle = null;
_loadedModelId = null;
_loadedContextLength = null;
_logger.debug('LLM component destroyed');
} catch (e) {
_logger.error('Failed to destroy LLM component: $e');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ class DartBridgeModelRegistry {
modelPtr.ref.localPath =
pathDart != null ? strdupFn(pathDart) : nullptr;
modelPtr.ref.downloadSize = model.sizeBytes;
modelPtr.ref.contextLength = model.contextLength;
modelPtr.ref.source = model.source;

final result = saveFn(_registryHandle!, modelPtr);
Expand Down Expand Up @@ -197,6 +198,7 @@ class DartBridgeModelRegistry {
framework: _frameworkToFfi(model.framework),
source: _sourceToFfi(model.source),
sizeBytes: model.downloadSize ?? 0,
contextLength: model.contextLength ?? 0,
downloadURL: model.downloadURL?.toString(),
localPath: model.localPath?.toFilePath(),
version: null,
Expand Down Expand Up @@ -385,6 +387,7 @@ class DartBridgeModelRegistry {
? Uri.file(ffiModel.localPath!)
: null,
downloadSize: ffiModel.sizeBytes > 0 ? ffiModel.sizeBytes : null,
contextLength: ffiModel.contextLength > 0 ? ffiModel.contextLength : null,
source: _sourceFromFfi(ffiModel.source),
);
}
Expand Down Expand Up @@ -797,6 +800,7 @@ class DartBridgeModelRegistry {
framework: struct.ref.framework,
source: struct.ref.source,
sizeBytes: struct.ref.downloadSize,
contextLength: struct.ref.contextLength,
downloadURL: struct.ref.downloadUrl != nullptr
? struct.ref.downloadUrl.toDartString()
: null,
Expand Down Expand Up @@ -1085,6 +1089,7 @@ class ModelInfo {
final int framework;
final int source;
final int sizeBytes;
final int contextLength;
final String? downloadURL;
final String? localPath;
final String? version;
Expand All @@ -1097,6 +1102,7 @@ class ModelInfo {
required this.framework,
required this.source,
required this.sizeBytes,
required this.contextLength,
this.downloadURL,
this.localPath,
this.version,
Expand All @@ -1112,6 +1118,7 @@ class ModelInfo {
'framework': framework,
'source': source,
'sizeBytes': sizeBytes,
'contextLength': contextLength,
if (downloadURL != null) 'downloadURL': downloadURL,
if (localPath != null) 'localPath': localPath,
if (version != null) 'version': version,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import 'dart:isolate';
import 'dart:typed_data';

import 'package:ffi/ffi.dart';

import 'package:runanywhere/features/stt/stt_configuration.dart';
import 'package:runanywhere/foundation/logging/sdk_logger.dart';
import 'package:runanywhere/native/ffi_types.dart';
import 'package:runanywhere/native/platform_loader.dart';
Expand Down Expand Up @@ -185,6 +185,8 @@ class DartBridgeSTT {
Uint8List audioData, {
int sampleRate = 16000,
}) async {
STTConfiguration(sampleRate: sampleRate).validate();

final handle = getHandle();

if (!isLoaded) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import 'dart:isolate';
import 'dart:typed_data';

import 'package:ffi/ffi.dart';

import 'package:runanywhere/features/tts/tts_configuration.dart';
import 'package:runanywhere/foundation/logging/sdk_logger.dart';
import 'package:runanywhere/native/ffi_types.dart';
import 'package:runanywhere/native/platform_loader.dart';
Expand Down Expand Up @@ -190,6 +190,12 @@ class DartBridgeTTS {
double pitch = 1.0,
double volume = 1.0,
}) async {
TTSConfiguration(
speakingRate: rate,
pitch: pitch,
volume: volume,
).validate();

final handle = getHandle();

if (!isLoaded) {
Expand Down
Loading