Skip to content

Commit 8a32d42

Browse files
committed
Merge remote-tracking branch 'alejandrojnm/master' into release/v1.41.2
2 parents 02e8eae + fa78712 commit 8a32d42

File tree

2 files changed

+121
-2
lines changed

2 files changed

+121
-2
lines changed

chat.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,26 @@ type ChatMessageImageURL struct {
8181
Detail ImageURLDetail `json:"detail,omitempty"`
8282
}
8383

84+
// ChatMessageFile is a placeholder for file parts in chat messages.
85+
type ChatMessageFile struct {
86+
FileID string `json:"file_id,omitempty"`
87+
FileName string `json:"filename,omitempty"`
88+
FileData string `json:"file_data,omitempty"` // Base64 encoded file data
89+
}
90+
8491
type ChatMessagePartType string
8592

8693
const (
8794
ChatMessagePartTypeText ChatMessagePartType = "text"
8895
ChatMessagePartTypeImageURL ChatMessagePartType = "image_url"
96+
ChatMessagePartTypeFile ChatMessagePartType = "file"
8997
)
9098

9199
type ChatMessagePart struct {
92100
Type ChatMessagePartType `json:"type,omitempty"`
93101
Text string `json:"text,omitempty"`
94102
ImageURL *ChatMessageImageURL `json:"image_url,omitempty"`
103+
File *ChatMessageFile `json:"file,omitempty"`
95104
}
96105

97106
type ChatCompletionMessage struct {

chat_test.go

Lines changed: 112 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,14 @@ func TestMultipartChatCompletions(t *testing.T) {
797797
Detail: openai.ImageURLDetailLow,
798798
},
799799
},
800+
{
801+
Type: openai.ChatMessagePartTypeFile,
802+
File: &openai.ChatMessageFile{
803+
FileID: "file-123",
804+
FileName: "test.txt",
805+
FileData: "dGVzdCBmaWxlIGNvbnRlbnQ=", // base64 encoded "test file content"
806+
},
807+
},
800808
},
801809
},
802810
},
@@ -807,7 +815,8 @@ func TestMultipartChatCompletions(t *testing.T) {
807815
func TestMultipartChatMessageSerialization(t *testing.T) {
808816
jsonText := `[{"role":"system","content":"system-message"},` +
809817
`{"role":"user","content":[{"type":"text","text":"nice-text"},` +
810-
`{"type":"image_url","image_url":{"url":"URL","detail":"high"}}]}]`
818+
`{"type":"image_url","image_url":{"url":"URL","detail":"high"}},` +
819+
`{"type":"file","file":{"file_id":"file-123","filename":"test.txt","file_data":"dGVzdA=="}}]}]`
811820

812821
var msgs []openai.ChatCompletionMessage
813822
err := json.Unmarshal([]byte(jsonText), &msgs)
@@ -820,7 +829,7 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
820829
if msgs[0].Role != "system" || msgs[0].Content != "system-message" || msgs[0].MultiContent != nil {
821830
t.Errorf("invalid user message: %v", msgs[0])
822831
}
823-
if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 2 {
832+
if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 3 {
824833
t.Errorf("invalid user message")
825834
}
826835
parts := msgs[1].MultiContent
@@ -830,6 +839,10 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
830839
if parts[1].Type != "image_url" || parts[1].ImageURL.URL != "URL" || parts[1].ImageURL.Detail != "high" {
831840
t.Errorf("invalid image_url part")
832841
}
842+
if parts[2].Type != "file" || parts[2].File.FileID != "file-123" ||
843+
parts[2].File.FileName != "test.txt" || parts[2].File.FileData != "dGVzdA==" {
844+
t.Errorf("invalid file part: %v", parts[2])
845+
}
833846

834847
s, err := json.Marshal(msgs)
835848
if err != nil {
@@ -876,6 +889,103 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
876889
}
877890
}
878891

892+
func TestChatMessageFile(t *testing.T) {
893+
// Test file part with FileID
894+
filePart := openai.ChatMessagePart{
895+
Type: openai.ChatMessagePartTypeFile,
896+
File: &openai.ChatMessageFile{
897+
FileID: "file-abc123",
898+
},
899+
}
900+
901+
// Test serialization
902+
data, err := json.Marshal(filePart)
903+
if err != nil {
904+
t.Fatalf("Expected no error: %s", err)
905+
}
906+
907+
expected := `{"type":"file","file":{"file_id":"file-abc123"}}`
908+
result := strings.ReplaceAll(string(data), " ", "")
909+
if result != expected {
910+
t.Errorf("Expected %s, got %s", expected, result)
911+
}
912+
913+
// Test deserialization
914+
var parsedPart openai.ChatMessagePart
915+
err = json.Unmarshal(data, &parsedPart)
916+
if err != nil {
917+
t.Fatalf("Expected no error: %s", err)
918+
}
919+
920+
if parsedPart.Type != openai.ChatMessagePartTypeFile {
921+
t.Errorf("Expected type %s, got %s", openai.ChatMessagePartTypeFile, parsedPart.Type)
922+
}
923+
if parsedPart.File == nil {
924+
t.Fatal("Expected File to be non-nil")
925+
}
926+
if parsedPart.File.FileID != "file-abc123" {
927+
t.Errorf("Expected FileID %s, got %s", "file-abc123", parsedPart.File.FileID)
928+
}
929+
930+
// Test file part with all fields
931+
filePartComplete := openai.ChatMessagePart{
932+
Type: openai.ChatMessagePartTypeFile,
933+
File: &openai.ChatMessageFile{
934+
FileID: "file-xyz789",
935+
FileName: "document.pdf",
936+
FileData: "JVBERi0xLjQK", // base64 for "%PDF-1.4\n"
937+
},
938+
}
939+
940+
data, err = json.Marshal(filePartComplete)
941+
if err != nil {
942+
t.Fatalf("Expected no error: %s", err)
943+
}
944+
945+
expected = `{"type":"file","file":{"file_id":"file-xyz789","filename":"document.pdf","file_data":"JVBERi0xLjQK"}}`
946+
result = strings.ReplaceAll(string(data), " ", "")
947+
if result != expected {
948+
t.Errorf("Expected %s, got %s", expected, result)
949+
}
950+
951+
// Test deserialization of complete file part
952+
var parsedCompleteFile openai.ChatMessagePart
953+
err = json.Unmarshal(data, &parsedCompleteFile)
954+
if err != nil {
955+
t.Fatalf("Expected no error: %s", err)
956+
}
957+
958+
if parsedCompleteFile.File.FileID != "file-xyz789" {
959+
t.Errorf("Expected FileID %s, got %s", "file-xyz789", parsedCompleteFile.File.FileID)
960+
}
961+
if parsedCompleteFile.File.FileName != "document.pdf" {
962+
t.Errorf("Expected FileName %s, got %s", "document.pdf", parsedCompleteFile.File.FileName)
963+
}
964+
if parsedCompleteFile.File.FileData != "JVBERi0xLjQK" {
965+
t.Errorf("Expected FileData %s, got %s", "JVBERi0xLjQK", parsedCompleteFile.File.FileData)
966+
}
967+
}
968+
969+
func TestChatMessagePartTypeConstants(t *testing.T) {
970+
// Test that the new file constant is properly defined
971+
if openai.ChatMessagePartTypeFile != "file" {
972+
t.Errorf("Expected ChatMessagePartTypeFile to be 'file', got %s", openai.ChatMessagePartTypeFile)
973+
}
974+
975+
// Test all part type constants
976+
expectedTypes := map[openai.ChatMessagePartType]string{
977+
openai.ChatMessagePartTypeText: "text",
978+
openai.ChatMessagePartTypeImageURL: "image_url",
979+
openai.ChatMessagePartTypeFile: "file",
980+
}
981+
982+
for constant, expected := range expectedTypes {
983+
if string(constant) != expected {
984+
t.Errorf("Expected %s to be %s, got %s", constant, expected, string(constant))
985+
}
986+
}
987+
}
988+
879989
// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
880990
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
881991
var err error

0 commit comments

Comments
 (0)