Skip to content

Commit 38e68b1

Browse files
Expose reset() method. (#14589)
Summary: . Differential Revision: D83220816 Co-authored-by: Anthony Shoumikhin <[email protected]>
1 parent 57e469c commit 38e68b1

File tree

6 files changed

+65
-7
lines changed

6 files changed

+65
-7
lines changed

extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,16 @@ withTokenCallback:(nullable void (^)(NSString *))callback
184184
error:(NSError **)error;
185185

186186
/**
187-
Stops any ongoing generation and cleans up internal resources.
187+
Stop producing new tokens and terminate the current generation process.
188188
*/
189189
- (void)stop;
190190

191+
/**
192+
Remove the prefilled tokens from the KV cache and resets the start position
193+
to 0. It also clears the stats for previous runs.
194+
*/
195+
- (void)reset;
196+
191197
+ (instancetype)new NS_UNAVAILABLE;
192198
- (instancetype)init NS_UNAVAILABLE;
193199

extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.mm

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,4 +216,10 @@ - (void)stop {
216216
}
217217
}
218218

219+
- (void)reset {
220+
if (_runner) {
221+
_runner->reset();
222+
}
223+
}
224+
219225
@end

extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMTextRunner.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,16 @@ withTokenCallback:(nullable void (^)(NSString *))callback
6464
error:(NSError **)error;
6565

6666
/**
67-
Stops any ongoing generation and cleans up internal resources.
67+
Stop producing new tokens and terminate the current generation process.
6868
*/
6969
- (void)stop;
7070

71+
/**
72+
Remove the prefilled tokens from the KV cache and resets the start position
73+
to 0. It also clears the stats for previous runs.
74+
*/
75+
- (void)reset;
76+
7177
+ (instancetype)new NS_UNAVAILABLE;
7278
- (instancetype)init NS_UNAVAILABLE;
7379

extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMTextRunner.mm

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,4 +101,10 @@ - (void)stop {
101101
}
102102
}
103103

104+
- (void)reset {
105+
if (_runner) {
106+
_runner->reset();
107+
}
108+
}
109+
104110
@end

extension/llm/apple/ExecuTorchLLM/__tests__/MultimodalRunnerTest.swift

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ extension UIImage {
4545
}
4646

4747
class MultimodalRunnerTest: XCTestCase {
48+
let systemPrompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: "
49+
let assistantPrompt = "ASSISTANT: "
50+
let userPrompt = "What's on the picture?"
51+
let sequenceLength = 768
52+
4853
func test() {
4954
let bundle = Bundle(for: type(of: self))
5055
guard let modelPath = bundle.path(forResource: "llava", ofType: "pte"),
@@ -59,10 +64,25 @@ class MultimodalRunnerTest: XCTestCase {
5964

6065
do {
6166
try runner.generate([
62-
MultimodalInput("A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: "),
67+
MultimodalInput(systemPrompt),
68+
MultimodalInput(image.asImage()),
69+
MultimodalInput("\(userPrompt) \(assistantPrompt)"),
70+
], sequenceLength: sequenceLength) { token in
71+
text += token
72+
}
73+
} catch {
74+
XCTFail("Failed to generate text with error \(error)")
75+
}
76+
XCTAssertTrue(text.lowercased().contains("waterfall"))
77+
78+
text = ""
79+
runner.reset()
80+
do {
81+
try runner.generate([
82+
MultimodalInput(systemPrompt),
6383
MultimodalInput(image.asImage()),
64-
MultimodalInput("What's on the picture? ASSISTANT: "),
65-
], sequenceLength: 768) { token in
84+
MultimodalInput("\(userPrompt) \(assistantPrompt)"),
85+
], sequenceLength: sequenceLength) { token in
6686
text += token
6787
}
6888
} catch {

extension/llm/apple/ExecuTorchLLM/__tests__/TextRunnerTest.swift

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ struct SpecialTokens {
3636
}
3737

3838
class TextRunnerTest: XCTestCase {
39+
let userPrompt = "The capital of France is called"
40+
let sequenceLength = 128
41+
3942
func test() {
4043
let bundle = Bundle(for: type(of: self))
4144
guard let modelPath = bundle.path(forResource: "llama3_2-1B", ofType: "pte"),
@@ -47,12 +50,23 @@ class TextRunnerTest: XCTestCase {
4750
var text = ""
4851

4952
do {
50-
try runner.generate("hello", sequenceLength: 2) { token in
53+
try runner.generate(userPrompt, sequenceLength: sequenceLength) { token in
54+
text += token
55+
}
56+
} catch {
57+
XCTFail("Failed to generate text with error \(error)")
58+
}
59+
XCTAssertTrue(text.lowercased().contains("paris"))
60+
61+
text = ""
62+
runner.reset()
63+
do {
64+
try runner.generate(userPrompt, sequenceLength: sequenceLength) { token in
5165
text += token
5266
}
5367
} catch {
5468
XCTFail("Failed to generate text with error \(error)")
5569
}
56-
XCTAssertEqual("hello,", text.lowercased())
70+
XCTAssertTrue(text.lowercased().contains("paris"))
5771
}
5872
}

0 commit comments

Comments
 (0)