Skip to content

Commit 8ec7a78

Browse files
authored
Minor clean up of struct (#265)
Signed-off-by: Hongxin Liang <[email protected]>
1 parent bd369ca commit 8ec7a78

File tree

9 files changed

+94
-32
lines changed

9 files changed

+94
-32
lines changed

flytekit-api/src/main/java/org/flyte/api/v1/BlobType.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
/** Defines type behavior for blob objects. */
2222
@AutoValue
2323
public abstract class BlobType {
24+
25+
public static final BlobType DEFAULT =
26+
BlobType.builder().dimensionality(BlobDimensionality.SINGLE).format("").build();
27+
2428
public enum BlobDimensionality {
2529
SINGLE,
2630
MULTIPART

flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import org.flyte.api.v1.Blob;
2828
import org.flyte.api.v1.BlobMetadata;
2929
import org.flyte.api.v1.BlobType;
30-
import org.flyte.api.v1.BlobType.BlobDimensionality;
3130
import org.flyte.examples.AllInputsTask.AutoAllInputsOutput;
3231
import org.flyte.examples.AllInputsTask.Nested;
3332
import org.flyte.flytekit.SdkBindingData;
@@ -66,14 +65,7 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput)
6665
SdkBindingDataFactory.of(
6766
Blob.builder()
6867
.uri("file://test/test.csv")
69-
.metadata(
70-
BlobMetadata.builder()
71-
.type(
72-
BlobType.builder()
73-
.format("")
74-
.dimensionality(BlobDimensionality.SINGLE)
75-
.build())
76-
.build())
68+
.metadata(BlobMetadata.builder().type(BlobType.DEFAULT).build())
7769
.build()),
7870
SdkBindingDataFactory.of(
7971
JacksonSdkLiteralType.of(Nested.class), Nested.create("hello", "world")),

flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import java.util.Map;
3333
import org.flyte.api.v1.Blob;
3434
import org.flyte.api.v1.BlobType;
35-
import org.flyte.api.v1.BlobType.BlobDimensionality;
3635
import org.flyte.api.v1.Variable;
3736
import org.flyte.flytekit.SdkBindingData;
3837
import org.flyte.flytekit.SdkLiteralType;
@@ -172,8 +171,7 @@ private SdkLiteralType<?> toLiteralType(
172171
// fixme: create blob type from annotation, or rethink how we could offer the offloaded data
173172
// feature
174173
// https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.BlobType.html#flytekit-blobtype
175-
return SdkLiteralTypes.blobs(
176-
BlobType.builder().format("").dimensionality(BlobDimensionality.SINGLE).build());
174+
return SdkLiteralTypes.blobs(BlobType.DEFAULT);
177175
}
178176
try {
179177
return JacksonSdkLiteralType.of(type);

flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@
5656

5757
public class JacksonSdkTypeTest {
5858

59-
private static final BlobType BLOB_TYPE =
60-
BlobType.builder().format("").dimensionality(BlobType.BlobDimensionality.SINGLE).build();
59+
private static final BlobType BLOB_TYPE = BlobType.DEFAULT;
6160

6261
private static final Blob BLOB =
6362
Blob.builder()

flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ class TestOfReturnsProperTypeProvider extends ArgumentsProvider {
6666
Arguments.of(booleans(), of[Boolean]()),
6767
Arguments.of(datetimes(), of[Instant]()),
6868
Arguments.of(durations(), of[Duration]()),
69+
Arguments.of(blobs(BlobType.DEFAULT), of[Blob]()),
70+
Arguments.of(generics(), of[ScalarNested]()),
6971
Arguments.of(collections(integers()), of[List[Long]]()),
7072
Arguments.of(collections(floats()), of[List[Double]]()),
7173
Arguments.of(collections(strings()), of[List[String]]()),

flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import org.flyte.api.v1.{
2929
Primitive,
3030
Scalar,
3131
SimpleType,
32+
Struct,
3233
Variable
3334
}
3435
import org.flyte.flytekit.{
@@ -40,17 +41,32 @@ import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows}
4041
import org.junit.jupiter.api.Test
4142
import org.flyte.examples.AllInputsTask.{AutoAllInputsInput, Nested}
4243
import org.flyte.flytekit.jackson.JacksonSdkLiteralType
43-
import org.flyte.flytekitscala.SdkLiteralTypes.{collections, maps, strings}
44+
import org.flyte.flytekitscala.SdkLiteralTypes.{
45+
blobs,
46+
collections,
47+
maps,
48+
strings
49+
}
50+
51+
// The constructor is reflectedly invoked so it cannot be an inner class
52+
case class ScalarNested(foo: String, bar: String)
4453

4554
class SdkScalaTypeTest {
4655

56+
private val blob = Blob.builder
57+
.metadata(BlobMetadata.builder.`type`(BlobType.DEFAULT).build)
58+
.uri("file://test")
59+
.build
60+
4761
case class ScalarInput(
4862
string: SdkBindingData[String],
4963
integer: SdkBindingData[Long],
5064
float: SdkBindingData[Double],
5165
boolean: SdkBindingData[Boolean],
5266
datetime: SdkBindingData[Instant],
53-
duration: SdkBindingData[Duration]
67+
duration: SdkBindingData[Duration],
68+
blob: SdkBindingData[Blob],
69+
generic: SdkBindingData[ScalarNested]
5470
)
5571

5672
case class CollectionInput(
@@ -116,7 +132,13 @@ class SdkScalaTypeTest {
116132
"float" -> createVar(SimpleType.FLOAT),
117133
"boolean" -> createVar(SimpleType.BOOLEAN),
118134
"datetime" -> createVar(SimpleType.DATETIME),
119-
"duration" -> createVar(SimpleType.DURATION)
135+
"duration" -> createVar(SimpleType.DURATION),
136+
"blob" -> Variable
137+
.builder()
138+
.literalType(LiteralType.ofBlobType(BlobType.DEFAULT))
139+
.description("")
140+
.build(),
141+
"generic" -> createVar(SimpleType.STRUCT)
120142
).asJava
121143

122144
val output = SdkScalaType[ScalarInput].getVariableMap
@@ -149,6 +171,17 @@ class SdkScalaTypeTest {
149171
),
150172
"duration" -> Literal.ofScalar(
151173
Scalar.ofPrimitive(Primitive.ofDuration(Duration.ofSeconds(123, 456)))
174+
),
175+
"blob" -> Literal.ofScalar(Scalar.ofBlob(blob)),
176+
"generic" -> Literal.ofScalar(
177+
Scalar.ofGeneric(
178+
Struct.of(
179+
Map(
180+
"foo" -> Struct.Value.ofStringValue("foo"),
181+
"bar" -> Struct.Value.ofStringValue("bar")
182+
).asJava
183+
)
184+
)
152185
)
153186
).asJava
154187

@@ -159,7 +192,12 @@ class SdkScalaTypeTest {
159192
float = SdkBindingDataFactory.of(42.0),
160193
boolean = SdkBindingDataFactory.of(true),
161194
datetime = SdkBindingDataFactory.of(Instant.ofEpochMilli(123456L)),
162-
duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456))
195+
duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)),
196+
blob = SdkBindingDataFactory.of(blob),
197+
generic = SdkBindingDataFactory.of(
198+
SdkLiteralTypes.generics(),
199+
ScalarNested("foo", "bar")
200+
)
163201
)
164202

165203
val output = SdkScalaType[ScalarInput].fromLiteralMap(input)
@@ -176,7 +214,12 @@ class SdkScalaTypeTest {
176214
float = SdkBindingDataFactory.of(42.0),
177215
boolean = SdkBindingDataFactory.of(true),
178216
datetime = SdkBindingDataFactory.of(Instant.ofEpochMilli(123456L)),
179-
duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456))
217+
duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)),
218+
blob = SdkBindingDataFactory.of(blob),
219+
generic = SdkBindingDataFactory.of(
220+
SdkLiteralTypes.generics(),
221+
ScalarNested("foo", "bar")
222+
)
180223
)
181224

182225
val expected = Map(
@@ -195,6 +238,17 @@ class SdkScalaTypeTest {
195238
),
196239
"duration" -> Literal.ofScalar(
197240
Scalar.ofPrimitive(Primitive.ofDuration(Duration.ofSeconds(123, 456)))
241+
),
242+
"blob" -> Literal.ofScalar(Scalar.ofBlob(blob)),
243+
"generic" -> Literal.ofScalar(
244+
Scalar.ofGeneric(
245+
Struct.of(
246+
Map(
247+
"foo" -> Struct.Value.ofStringValue("foo"),
248+
"bar" -> Struct.Value.ofStringValue("bar")
249+
).asJava
250+
)
251+
)
198252
)
199253
).asJava
200254

@@ -227,7 +281,12 @@ class SdkScalaTypeTest {
227281
float = SdkBindingDataFactory.of(42.0),
228282
boolean = SdkBindingDataFactory.of(true),
229283
datetime = SdkBindingDataFactory.of(Instant.ofEpochMilli(123456L)),
230-
duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456))
284+
duration = SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)),
285+
blob = SdkBindingDataFactory.of(blob),
286+
generic = SdkBindingDataFactory.of(
287+
SdkLiteralTypes.generics(),
288+
ScalarNested("foo", "bar")
289+
)
231290
)
232291

233292
val output = SdkScalaType[ScalarInput].toSdkBindingMap(input)
@@ -238,7 +297,12 @@ class SdkScalaTypeTest {
238297
"float" -> SdkBindingDataFactory.of(42.0),
239298
"boolean" -> SdkBindingDataFactory.of(true),
240299
"datetime" -> SdkBindingDataFactory.of(Instant.ofEpochMilli(123456L)),
241-
"duration" -> SdkBindingDataFactory.of(Duration.ofSeconds(123, 456))
300+
"duration" -> SdkBindingDataFactory.of(Duration.ofSeconds(123, 456)),
301+
"blob" -> SdkBindingDataFactory.of(blob),
302+
"generic" -> SdkBindingDataFactory.of(
303+
SdkLiteralTypes.generics[ScalarNested](),
304+
ScalarNested("foo", "bar")
305+
)
242306
).asJava
243307

244308
assertEquals(expected, output)

flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,10 @@ object SdkBindingDataConverters {
150150
SdkScalaLiteralTypes.strings(),
151151
jf.Function.identity()
152152
)
153-
case SimpleType.STRUCT => ??? // TODO not yet supported
153+
case SimpleType.STRUCT =>
154+
throw new UnsupportedOperationException(
155+
"Converting Scala case class instance to Java object is not supported"
156+
)
154157
case SimpleType.BOOLEAN =>
155158
TypeCastingResult(
156159
SdkScalaLiteralTypes.booleans(),
@@ -239,7 +242,9 @@ object SdkBindingDataConverters {
239242
jf.Function.identity()
240243
)
241244
case SimpleType.STRUCT =>
242-
??? // TODO how to handle? do we support structs already?
245+
throw new UnsupportedOperationException(
246+
"Converting Java object to Scala case class instance is not supported"
247+
)
243248
case SimpleType.BOOLEAN =>
244249
TypeCastingResult(
245250
SdkJavaLiteralTypes.booleans(),

flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ object SdkLiteralTypes {
6666
datetimes().asInstanceOf[SdkLiteralType[T]]
6767
case t if t =:= typeOf[Duration] =>
6868
durations().asInstanceOf[SdkLiteralType[T]]
69+
case t if t =:= typeOf[Blob] =>
70+
blobs(BlobType.DEFAULT).asInstanceOf[SdkLiteralType[T]]
71+
case t if t <:< typeOf[Product] && !(t =:= typeOf[Option[_]]) =>
72+
generics().asInstanceOf[SdkLiteralType[T]]
6973

7074
case t if t =:= typeOf[List[Long]] =>
7175
collections(integers()).asInstanceOf[SdkLiteralType[T]]

flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ import org.flyte.flytekit.{
3030

3131
import scala.annotation.implicitNotFound
3232
import scala.collection.JavaConverters._
33-
import scala.reflect.{ClassTag, classTag}
34-
import scala.reflect.runtime.universe.{TypeTag, typeOf}
33+
import scala.reflect.ClassTag
34+
import scala.reflect.runtime.universe.TypeTag
3535

3636
/** Type class to map between Flyte `Variable` and `Literal` and Scala case
3737
* classes.
@@ -245,13 +245,7 @@ object SdkScalaType {
245245
// https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.BlobType.html#flytekit-blobtype
246246
implicit def blobLiteralType: SdkScalaLiteralType[Blob] =
247247
DelegateLiteralType(
248-
SdkLiteralTypes.blobs(
249-
BlobType
250-
.builder()
251-
.format("")
252-
.dimensionality(BlobDimensionality.SINGLE)
253-
.build()
254-
)
248+
SdkLiteralTypes.blobs(BlobType.DEFAULT)
255249
)
256250

257251
// TODO we are forced to do this because SdkDataBinding.ofInteger returns a SdkBindingData<java.util.Long>

0 commit comments

Comments
 (0)