@@ -148,9 +148,9 @@ struct ContentView: View {
 }
 
 @Observable
+@MainActor
 class LLMEvaluator {
 
-    @MainActor
     var running = false
 
     var output = ""
@@ -172,91 +172,87 @@ class LLMEvaluator {
 
     enum LoadState {
         case idle
-        case loaded(LLMModel, Tokenizers.Tokenizer)
+        case loaded(ModelContainer)
     }
 
     var loadState = LoadState.idle
 
     /// load and return the model -- can be called multiple times, subsequent calls will
     /// just return the loaded model
-    func load() async throws -> (LLMModel, Tokenizers.Tokenizer) {
+    func load() async throws -> ModelContainer {
         switch loadState {
         case .idle:
             // limit the buffer cache
             MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)
 
-            let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) {
+            let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
+            {
                 [modelConfiguration] progress in
-                DispatchQueue.main.sync {
+                Task { @MainActor in
                     self.modelInfo =
                         "Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
                 }
             }
             self.modelInfo =
                 "Loaded \(modelConfiguration.id). Weights: \(MLX.GPU.activeMemory / 1024 / 1024)M"
-            loadState = .loaded(model, tokenizer)
-            return (model, tokenizer)
+            loadState = .loaded(modelContainer)
+            return modelContainer
 
-        case .loaded(let model, let tokenizer):
-            return (model, tokenizer)
+        case .loaded(let modelContainer):
+            return modelContainer
         }
     }
 
     func generate(prompt: String) async {
-        let canGenerate = await MainActor.run {
-            if running {
-                return false
-            } else {
-                running = true
-                self.output = ""
-                return true
-            }
-        }
+        guard !running else { return }
 
-        guard canGenerate else { return }
+        running = true
+        self.output = ""
 
         do {
-            let (model, tokenizer) = try await load()
+            let modelContainer = try await load()
+
             // augment the prompt as needed
             let prompt = modelConfiguration.prepare(prompt: prompt)
-            let promptTokens = tokenizer.encode(text: prompt)
+
+            let promptTokens = await modelContainer.perform { _, tokenizer in
+                tokenizer.encode(text: prompt)
+            }
 
             // each time you generate you will get something new
             MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
 
-            let result = await LLM.generate(
-                promptTokens: promptTokens, parameters: generateParameters, model: model,
-                tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
-            ) { tokens in
-                // update the output -- this will make the view show the text as it generates
-                if tokens.count % displayEveryNTokens == 0 {
-                    let text = tokenizer.decode(tokens: tokens)
-                    await MainActor.run {
-                        self.output = text
+            let result = await modelContainer.perform { model, tokenizer in
+                LLM.generate(
+                    promptTokens: promptTokens, parameters: generateParameters, model: model,
+                    tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
+                ) { tokens in
+                    // update the output -- this will make the view show the text as it generates
+                    if tokens.count % displayEveryNTokens == 0 {
+                        let text = tokenizer.decode(tokens: tokens)
+                        Task { @MainActor in
+                            self.output = text
+                        }
                     }
-                }
 
-                if tokens.count >= maxTokens {
-                    return .stop
-                } else {
-                    return .more
+                    if tokens.count >= maxTokens {
+                        return .stop
+                    } else {
+                        return .more
+                    }
                 }
             }
 
             // update the text if needed, e.g. we haven't displayed because of displayEveryNTokens
-            await MainActor.run {
-                if result.output != self.output {
-                    self.output = result.output
-                }
-                running = false
-                self.stat = "Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"
+            if result.output != self.output {
+                self.output = result.output
             }
+            self.stat = "Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"
 
         } catch {
-            await MainActor.run {
-                running = false
-                output = "Failed: \(error)"
-            }
+            output = "Failed: \(error)"
         }
+
+        running = false
     }
 }
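For reference, the `ModelContainer` that replaces the bare `(model, tokenizer)` tuple is an actor that owns the model and tokenizer and funnels all access through `perform`, so the non-Sendable model state never leaves the actor's isolation. Below is a minimal sketch of that shape, an assumption about the pattern rather than the library's actual declaration; `Model`, `Tokenizer`, and `ModelContainerSketch` are placeholder names:

```swift
import Foundation

/// Minimal sketch of the container pattern -- not the library's actual
/// implementation. The real `ModelContainer` plays this role for the
/// loaded model and tokenizer.
actor ModelContainerSketch<Model, Tokenizer> {
    private let model: Model
    private let tokenizer: Tokenizer

    init(model: Model, tokenizer: Tokenizer) {
        self.model = model
        self.tokenizer = tokenizer
    }

    /// Run `action` with exclusive access to the model and tokenizer.
    /// Actor isolation serializes callers, so the model is never touched
    /// by two tasks at once and never crosses isolation domains directly.
    func perform<R: Sendable>(_ action: @Sendable (Model, Tokenizer) throws -> R) rethrows -> R {
        try action(model, tokenizer)
    }
}
```

Callers hop onto the actor with `await container.perform { model, tokenizer in ... }`. This also explains why the token callback in the diff uses `Task { @MainActor in ... }` rather than `await MainActor.run`: the closure runs synchronously inside the actor, so it can only schedule the UI update on the now `@MainActor`-isolated `LLMEvaluator`, not await it.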