Skip to content

Commit 04f2d3c

Browse files
feat(client): accept InputStream and Path for file params (#277)
1 parent c9a19c3 commit 04f2d3c

15 files changed

+422
-83
lines changed

openai-java-core/src/main/kotlin/com/openai/core/ObjectMappers.kt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,43 @@
33
package com.openai.core
44

55
import com.fasterxml.jackson.annotation.JsonInclude
6+
import com.fasterxml.jackson.core.JsonGenerator
67
import com.fasterxml.jackson.databind.DeserializationFeature
78
import com.fasterxml.jackson.databind.SerializationFeature
9+
import com.fasterxml.jackson.databind.SerializerProvider
810
import com.fasterxml.jackson.databind.cfg.CoercionAction.Fail
911
import com.fasterxml.jackson.databind.cfg.CoercionInputShape.Integer
1012
import com.fasterxml.jackson.databind.json.JsonMapper
13+
import com.fasterxml.jackson.databind.module.SimpleModule
1114
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module
1215
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule
1316
import com.fasterxml.jackson.module.kotlin.jacksonMapperBuilder
17+
import java.io.InputStream
1418

1519
fun jsonMapper(): JsonMapper =
1620
jacksonMapperBuilder()
1721
.addModule(Jdk8Module())
1822
.addModule(JavaTimeModule())
23+
.addModule(SimpleModule().addSerializer(InputStreamJsonSerializer))
1924
.serializationInclusion(JsonInclude.Include.NON_ABSENT)
2025
.disable(DeserializationFeature.ADJUST_DATES_TO_CONTEXT_TIME_ZONE)
2126
.disable(SerializationFeature.FLUSH_AFTER_WRITE_VALUE)
2227
.disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS)
2328
.disable(SerializationFeature.WRITE_DURATIONS_AS_TIMESTAMPS)
2429
.withCoercionConfig(String::class.java) { it.setCoercion(Integer, Fail) }
2530
.build()
31+
32+
private object InputStreamJsonSerializer : BaseSerializer<InputStream>(InputStream::class) {
33+
34+
override fun serialize(
35+
value: InputStream?,
36+
gen: JsonGenerator?,
37+
serializers: SerializerProvider?,
38+
) {
39+
if (value == null) {
40+
gen?.writeNull()
41+
} else {
42+
value.use { gen?.writeBinary(it.readBytes()) }
43+
}
44+
}
45+
}

openai-java-core/src/main/kotlin/com/openai/core/Values.kt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import com.fasterxml.jackson.databind.node.JsonNodeType.POJO
2727
import com.fasterxml.jackson.databind.node.JsonNodeType.STRING
2828
import com.fasterxml.jackson.databind.ser.std.NullSerializer
2929
import com.openai.errors.OpenAIInvalidDataException
30+
import java.io.InputStream
3031
import java.util.Objects
3132
import java.util.Optional
3233

@@ -508,7 +509,10 @@ private constructor(
508509
return MultipartField(
509510
value,
510511
contentType
511-
?: if (value is KnownValue && value.value is ByteArray)
512+
?: if (
513+
value is KnownValue &&
514+
(value.value is InputStream || value.value is ByteArray)
515+
)
512516
"application/octet-stream"
513517
else "text/plain; charset=utf-8",
514518
filename,

openai-java-core/src/main/kotlin/com/openai/core/http/HttpRequestBodies.kt

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import com.fasterxml.jackson.databind.json.JsonMapper
99
import com.fasterxml.jackson.databind.node.JsonNodeType
1010
import com.openai.core.MultipartField
1111
import com.openai.errors.OpenAIInvalidDataException
12+
import java.io.ByteArrayInputStream
13+
import java.io.InputStream
1214
import java.io.OutputStream
1315
import kotlin.jvm.optionals.getOrNull
1416
import org.apache.hc.client5.http.entity.mime.MultipartEntityBuilder
@@ -41,8 +43,18 @@ internal fun multipartFormData(
4143
MultipartEntityBuilder.create()
4244
.apply {
4345
fields.forEach { (name, field) ->
44-
val node = jsonMapper.valueToTree<JsonNode>(field.value)
45-
serializePart(name, node).forEach { (name, bytes) ->
46+
val knownValue = field.value.asKnown().getOrNull()
47+
val parts =
48+
if (knownValue is InputStream) {
49+
// Read directly from the `InputStream` instead of reading it all
50+
// into memory due to the `jsonMapper` serialization below.
51+
sequenceOf(name to knownValue)
52+
} else {
53+
val node = jsonMapper.valueToTree<JsonNode>(field.value)
54+
serializePart(name, node)
55+
}
56+
57+
parts.forEach { (name, bytes) ->
4658
addBinaryBody(
4759
name,
4860
bytes,
@@ -55,16 +67,19 @@ internal fun multipartFormData(
5567
.build()
5668
}
5769

58-
private fun serializePart(name: String, node: JsonNode): Sequence<Pair<String, ByteArray>> =
70+
private fun serializePart(
71+
name: String,
72+
node: JsonNode,
73+
): Sequence<Pair<String, InputStream>> =
5974
when (node.nodeType) {
6075
JsonNodeType.MISSING,
6176
JsonNodeType.NULL -> emptySequence()
62-
JsonNodeType.BINARY -> sequenceOf(name to node.binaryValue())
63-
JsonNodeType.STRING -> sequenceOf(name to node.textValue().toByteArray())
77+
JsonNodeType.BINARY -> sequenceOf(name to ByteArrayInputStream(node.binaryValue()))
78+
JsonNodeType.STRING -> sequenceOf(name to node.textValue().toInputStream())
6479
JsonNodeType.BOOLEAN ->
65-
sequenceOf(name to node.booleanValue().toString().toByteArray())
80+
sequenceOf(name to node.booleanValue().toString().toInputStream())
6681
JsonNodeType.NUMBER ->
67-
sequenceOf(name to node.numberValue().toString().toByteArray())
82+
sequenceOf(name to node.numberValue().toString().toInputStream())
6883
JsonNodeType.ARRAY ->
6984
node.elements().asSequence().flatMap { element ->
7085
serializePart("$name[]", element)
@@ -78,6 +93,8 @@ internal fun multipartFormData(
7893
throw OpenAIInvalidDataException("Unexpected JsonNode type: ${node.nodeType}")
7994
}
8095

96+
private fun String.toInputStream(): InputStream = ByteArrayInputStream(toByteArray())
97+
8198
override fun writeTo(outputStream: OutputStream) = entity.writeTo(outputStream)
8299

83100
override fun contentType(): String = entity.contentType

openai-java-core/src/main/kotlin/com/openai/models/AudioTranscriptionCreateParams.kt

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,13 @@ import com.openai.core.http.Headers
1414
import com.openai.core.http.QueryParams
1515
import com.openai.core.toImmutable
1616
import com.openai.errors.OpenAIInvalidDataException
17+
import java.io.ByteArrayInputStream
18+
import java.io.InputStream
19+
import java.nio.file.Path
1720
import java.util.Objects
1821
import java.util.Optional
22+
import kotlin.io.path.inputStream
23+
import kotlin.io.path.name
1924

2025
/** Transcribes audio into the input language. */
2126
class AudioTranscriptionCreateParams
@@ -29,7 +34,7 @@ private constructor(
2934
* The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4,
3035
* mpeg, mpga, m4a, ogg, wav, or webm.
3136
*/
32-
fun file(): ByteArray = body.file()
37+
fun file(): InputStream = body.file()
3338

3439
/**
3540
* ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2
@@ -78,7 +83,7 @@ private constructor(
7883
* The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4,
7984
* mpeg, mpga, m4a, ogg, wav, or webm.
8085
*/
81-
fun _file(): MultipartField<ByteArray> = body._file()
86+
fun _file(): MultipartField<InputStream> = body._file()
8287

8388
/**
8489
* ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2
@@ -148,7 +153,7 @@ private constructor(
148153
class Body
149154
@JsonCreator
150155
private constructor(
151-
private val file: MultipartField<ByteArray>,
156+
private val file: MultipartField<InputStream>,
152157
private val model: MultipartField<AudioModel>,
153158
private val language: MultipartField<String>,
154159
private val prompt: MultipartField<String>,
@@ -161,7 +166,7 @@ private constructor(
161166
* The audio file object (not file name) to transcribe, in one of these formats: flac, mp3,
162167
* mp4, mpeg, mpga, m4a, ogg, wav, or webm.
163168
*/
164-
fun file(): ByteArray = file.value.getRequired("file")
169+
fun file(): InputStream = file.value.getRequired("file")
165170

166171
/**
167172
* ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2
@@ -214,7 +219,7 @@ private constructor(
214219
* The audio file object (not file name) to transcribe, in one of these formats: flac, mp3,
215220
* mp4, mpeg, mpga, m4a, ogg, wav, or webm.
216221
*/
217-
fun _file(): MultipartField<ByteArray> = file
222+
fun _file(): MultipartField<InputStream> = file
218223

219224
/**
220225
* ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2
@@ -296,7 +301,7 @@ private constructor(
296301
/** A builder for [Body]. */
297302
class Builder internal constructor() {
298303

299-
private var file: MultipartField<ByteArray>? = null
304+
private var file: MultipartField<InputStream>? = null
300305
private var model: MultipartField<AudioModel>? = null
301306
private var language: MultipartField<String> = MultipartField.of(null)
302307
private var prompt: MultipartField<String> = MultipartField.of(null)
@@ -321,13 +326,31 @@ private constructor(
321326
* The audio file object (not file name) to transcribe, in one of these formats: flac,
322327
* mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
323328
*/
324-
fun file(file: ByteArray) = file(MultipartField.of(file))
329+
fun file(file: InputStream) = file(MultipartField.of(file))
325330

326331
/**
327332
* The audio file object (not file name) to transcribe, in one of these formats: flac,
328333
* mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
329334
*/
330-
fun file(file: MultipartField<ByteArray>) = apply { this.file = file }
335+
fun file(file: MultipartField<InputStream>) = apply { this.file = file }
336+
337+
/**
338+
* The audio file object (not file name) to transcribe, in one of these formats: flac,
339+
* mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
340+
*/
341+
fun file(file: ByteArray) = file(ByteArrayInputStream(file))
342+
343+
/**
344+
* The audio file object (not file name) to transcribe, in one of these formats: flac,
345+
* mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
346+
*/
347+
fun file(file: Path) =
348+
file(
349+
MultipartField.builder<InputStream>()
350+
.value(file.inputStream())
351+
.filename(file.name)
352+
.build()
353+
)
331354

332355
/**
333356
* ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper
@@ -506,6 +529,18 @@ private constructor(
506529
additionalQueryParams = audioTranscriptionCreateParams.additionalQueryParams.toBuilder()
507530
}
508531

532+
/**
533+
* The audio file object (not file name) to transcribe, in one of these formats: flac, mp3,
534+
* mp4, mpeg, mpga, m4a, ogg, wav, or webm.
535+
*/
536+
fun file(file: InputStream) = apply { body.file(file) }
537+
538+
/**
539+
* The audio file object (not file name) to transcribe, in one of these formats: flac, mp3,
540+
* mp4, mpeg, mpga, m4a, ogg, wav, or webm.
541+
*/
542+
fun file(file: MultipartField<InputStream>) = apply { body.file(file) }
543+
509544
/**
510545
* The audio file object (not file name) to transcribe, in one of these formats: flac, mp3,
511546
* mp4, mpeg, mpga, m4a, ogg, wav, or webm.
@@ -516,7 +551,7 @@ private constructor(
516551
* The audio file object (not file name) to transcribe, in one of these formats: flac, mp3,
517552
* mp4, mpeg, mpga, m4a, ogg, wav, or webm.
518553
*/
519-
fun file(file: MultipartField<ByteArray>) = apply { body.file(file) }
554+
fun file(file: Path) = apply { body.file(file) }
520555

521556
/**
522557
* ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2

openai-java-core/src/main/kotlin/com/openai/models/AudioTranslationCreateParams.kt

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,13 @@ import com.openai.core.checkRequired
1010
import com.openai.core.http.Headers
1111
import com.openai.core.http.QueryParams
1212
import com.openai.core.toImmutable
13+
import java.io.ByteArrayInputStream
14+
import java.io.InputStream
15+
import java.nio.file.Path
1316
import java.util.Objects
1417
import java.util.Optional
18+
import kotlin.io.path.inputStream
19+
import kotlin.io.path.name
1520

1621
/** Translates audio into English. */
1722
class AudioTranslationCreateParams
@@ -25,7 +30,7 @@ private constructor(
2530
* The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4,
2631
* mpeg, mpga, m4a, ogg, wav, or webm.
2732
*/
28-
fun file(): ByteArray = body.file()
33+
fun file(): InputStream = body.file()
2934

3035
/**
3136
* ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2
@@ -58,7 +63,7 @@ private constructor(
5863
* The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4,
5964
* mpeg, mpga, m4a, ogg, wav, or webm.
6065
*/
61-
fun _file(): MultipartField<ByteArray> = body._file()
66+
fun _file(): MultipartField<InputStream> = body._file()
6267

6368
/**
6469
* ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2
@@ -110,7 +115,7 @@ private constructor(
110115
class Body
111116
@JsonCreator
112117
private constructor(
113-
private val file: MultipartField<ByteArray>,
118+
private val file: MultipartField<InputStream>,
114119
private val model: MultipartField<AudioModel>,
115120
private val prompt: MultipartField<String>,
116121
private val responseFormat: MultipartField<AudioResponseFormat>,
@@ -121,7 +126,7 @@ private constructor(
121126
* The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4,
122127
* mpeg, mpga, m4a, ogg, wav, or webm.
123128
*/
124-
fun file(): ByteArray = file.value.getRequired("file")
129+
fun file(): InputStream = file.value.getRequired("file")
125130

126131
/**
127132
* ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2
@@ -157,7 +162,7 @@ private constructor(
157162
* The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4,
158163
* mpeg, mpga, m4a, ogg, wav, or webm.
159164
*/
160-
fun _file(): MultipartField<ByteArray> = file
165+
fun _file(): MultipartField<InputStream> = file
161166

162167
/**
163168
* ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2
@@ -221,7 +226,7 @@ private constructor(
221226
/** A builder for [Body]. */
222227
class Builder internal constructor() {
223228

224-
private var file: MultipartField<ByteArray>? = null
229+
private var file: MultipartField<InputStream>? = null
225230
private var model: MultipartField<AudioModel>? = null
226231
private var prompt: MultipartField<String> = MultipartField.of(null)
227232
private var responseFormat: MultipartField<AudioResponseFormat> =
@@ -241,13 +246,31 @@ private constructor(
241246
* The audio file object (not file name) translate, in one of these formats: flac, mp3,
242247
* mp4, mpeg, mpga, m4a, ogg, wav, or webm.
243248
*/
244-
fun file(file: ByteArray) = file(MultipartField.of(file))
249+
fun file(file: InputStream) = file(MultipartField.of(file))
245250

246251
/**
247252
* The audio file object (not file name) translate, in one of these formats: flac, mp3,
248253
* mp4, mpeg, mpga, m4a, ogg, wav, or webm.
249254
*/
250-
fun file(file: MultipartField<ByteArray>) = apply { this.file = file }
255+
fun file(file: MultipartField<InputStream>) = apply { this.file = file }
256+
257+
/**
258+
* The audio file object (not file name) translate, in one of these formats: flac, mp3,
259+
* mp4, mpeg, mpga, m4a, ogg, wav, or webm.
260+
*/
261+
fun file(file: ByteArray) = file(ByteArrayInputStream(file))
262+
263+
/**
264+
* The audio file object (not file name) translate, in one of these formats: flac, mp3,
265+
* mp4, mpeg, mpga, m4a, ogg, wav, or webm.
266+
*/
267+
fun file(file: Path) =
268+
file(
269+
MultipartField.builder<InputStream>()
270+
.value(file.inputStream())
271+
.filename(file.name)
272+
.build()
273+
)
251274

252275
/**
253276
* ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper
@@ -375,6 +398,18 @@ private constructor(
375398
additionalQueryParams = audioTranslationCreateParams.additionalQueryParams.toBuilder()
376399
}
377400

401+
/**
402+
* The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4,
403+
* mpeg, mpga, m4a, ogg, wav, or webm.
404+
*/
405+
fun file(file: InputStream) = apply { body.file(file) }
406+
407+
/**
408+
* The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4,
409+
* mpeg, mpga, m4a, ogg, wav, or webm.
410+
*/
411+
fun file(file: MultipartField<InputStream>) = apply { body.file(file) }
412+
378413
/**
379414
* The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4,
380415
* mpeg, mpga, m4a, ogg, wav, or webm.
@@ -385,7 +420,7 @@ private constructor(
385420
* The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4,
386421
* mpeg, mpga, m4a, ogg, wav, or webm.
387422
*/
388-
fun file(file: MultipartField<ByteArray>) = apply { body.file(file) }
423+
fun file(file: Path) = apply { body.file(file) }
389424

390425
/**
391426
* ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2

0 commit comments

Comments
 (0)