|
3 | 3 | module DiscourseAi |
4 | 4 | module Completions |
# Accumulates the properties of a streamed JSON response and hands them back
# incrementally.
#
# Raw text is pushed in with `<<` and forwarded to a JsonStreamingTracker,
# which is given this object at construction time. Values reported through
# `notify_progress` are stored per property; `read_latest_buffered_chunk`
# then returns, for string properties, only the part that has not been read
# yet, and for every other property the latest value as is.
class StructuredOutput
  # NOTE(review): nothing in this class assigns @last_chunk_buffer, so this
  # reader returns nil unless it is set elsewhere. Kept for compatibility.
  attr_reader :last_chunk_buffer

  # @param json_schema_properties [Hash] property name => property schema.
  #   Names may be strings or symbols. The schema's type is read from the
  #   `:type` key, falling back to `"type"` for string-keyed schemas.
  def initialize(json_schema_properties)
    @property_names = json_schema_properties.keys.map(&:to_sym)

    # Only string properties get a read cursor (the offset of the first
    # character that has not been returned yet). The absence of a cursor is
    # what marks a property as "return the whole value every time".
    @property_cursors =
      json_schema_properties.each_with_object({}) do |(name, prop), cursors|
        type = prop[:type] || prop["type"]
        cursors[name.to_sym] = 0 if type == "string"
      end

    # Latest value reported for each known property.
    @tracked = {}

    @partial_json_tracker = JsonStreamingTracker.new(self)
  end

  # Feeds a raw chunk of the response to the streaming JSON tracker.
  #
  # @param raw [String] the next piece of the raw response
  # @return whatever the tracker's `<<` returns
  def <<(raw)
    @partial_json_tracker << raw
  end

  # Returns what is new since the previous call.
  #
  # String properties yield only their unread suffix and are omitted when
  # nothing new arrived. Properties without a cursor are returned in full on
  # every call once they have a value.
  #
  # @return [Hash{Symbol => Object}]
  def read_latest_buffered_chunk
    @property_names.each_with_object({}) do |name, chunk|
      value = @tracked[name]

      # Only nil means "nothing reported yet". A truthiness or blank check
      # here would silently drop a legitimate `false`.
      next if value.nil?

      if @property_cursors.key?(name)
        unread = value[@property_cursors[name]..]

        # Whitespace-only deltas are real content, so only skip empty ones.
        chunk[name] = unread unless unread.nil? || unread.empty?
        @property_cursors[name] = value.length
      else
        # Ints and bools are always returned as is.
        chunk[name] = value
      end
    end
  end

  # Records the latest value seen for a property. Keys that were not part of
  # the schema passed to `initialize` are ignored.
  #
  # @param key [String, Symbol] property name
  # @param value [Object] latest value for that property
  def notify_progress(key, value)
    key_sym = key.to_sym
    return unless @property_names.include?(key_sym)

    @tracked[key_sym] = value
  end
end
93 | 51 | end |
|
0 commit comments