Skip to content

Commit 21c9e2e

Browse files
authored
Handle Option correctly (#266)
* Handle Option correctly Signed-off-by: Hongxin Liang <[email protected]> * Clearer comment Signed-off-by: Hongxin Liang <[email protected]> * IT Signed-off-by: Hongxin Liang <[email protected]> * Early return of None Signed-off-by: Hongxin Liang <[email protected]> * Enrich IT workflow Signed-off-by: Hongxin Liang <[email protected]> * Resource Signed-off-by: Hongxin Liang <[email protected]> --------- Signed-off-by: Hongxin Liang <[email protected]>
1 parent 8ec7a78 commit 21c9e2e

File tree

6 files changed

+98
-55
lines changed

6 files changed

+98
-55
lines changed

flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ org.flyte.examples.flytekitscala.GreetTask
44
org.flyte.examples.flytekitscala.AddQuestionTask
55
org.flyte.examples.flytekitscala.NoInputsTask
66
org.flyte.examples.flytekitscala.NestedIOTask
7+
org.flyte.examples.flytekitscala.NestedIOTaskNoop

flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/LaunchPlanRegistry.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,19 @@ class LaunchPlanRegistry extends SimpleSdkLaunchPlanRegistry {
7373
6.toDouble,
7474
"hello",
7575
List("1", "2"),
76-
List(NestedNested(7.toDouble, NestedNestedNested("world"))),
76+
List(NestedNested(7.toDouble, Some(NestedNestedNested("world")))),
7777
Map("1" -> "1", "2" -> "2"),
78-
Map("foo" -> NestedNested(7.toDouble, NestedNestedNested("world"))),
78+
Map(
79+
"foo" -> NestedNested(
80+
7.toDouble,
81+
Some(NestedNestedNested("world"))
82+
)
83+
),
7984
Some(false),
8085
None,
8186
Some(List("3", "4")),
8287
Some(Map("3" -> "3", "4" -> "4")),
83-
NestedNested(7.toDouble, NestedNestedNested("world"))
88+
NestedNested(7.toDouble, Some(NestedNestedNested("world")))
8489
)
8590
)
8691
)

flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOTask.scala

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.flyte.flytekitscala.{
2424
}
2525

2626
case class NestedNestedNested(string: String)
27-
case class NestedNested(double: Double, nested: NestedNestedNested)
27+
case class NestedNested(double: Double, nested: Option[NestedNestedNested])
2828
case class Nested(
2929
boolean: Boolean,
3030
byte: Byte,
@@ -57,9 +57,6 @@ case class NestedIOTaskOutput(
5757
generic: SdkBindingData[Nested]
5858
)
5959

60-
/** Example Flyte task that takes a name as the input and outputs a simple
61-
* greeting message.
62-
*/
6360
class NestedIOTask
6461
extends SdkRunnableTask[
6562
NestedIOTaskInput,
@@ -69,17 +66,21 @@ class NestedIOTask
6966
SdkScalaType[NestedIOTaskOutput]
7067
) {
7168

72-
/** Defines task behavior. This task takes a name as the input, wraps it in a
73-
* welcome message, and outputs the message.
74-
*
75-
* @param input
76-
* the name of the person to be greeted
77-
* @return
78-
* the welcome message
79-
*/
8069
override def run(input: NestedIOTaskInput): NestedIOTaskOutput =
8170
NestedIOTaskOutput(
8271
input.name,
8372
input.generic
8473
)
8574
}
75+
76+
class NestedIOTaskNoop
77+
extends SdkRunnableTask[
78+
NestedIOTaskOutput,
79+
NestedIOTaskOutput
80+
](
81+
SdkScalaType[NestedIOTaskOutput],
82+
SdkScalaType[NestedIOTaskOutput]
83+
) {
84+
85+
override def run(input: NestedIOTaskOutput): NestedIOTaskOutput = input
86+
}

flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/NestedIOWorkflow.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class NestedIOWorkflow
3232
builder: SdkScalaWorkflowBuilder,
3333
input: NestedIOTaskInput
3434
): Unit = {
35-
builder.apply(new NestedIOTask(), input)
35+
val output = builder.apply(new NestedIOTask(), input)
36+
builder.apply(new NestedIOTaskNoop(), output.getOutputs)
3637
}
3738
}

flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,12 @@ import org.flyte.flytekitscala.SdkLiteralTypes.{
4949
}
5050

5151
// The constructor is reflectedly invoked so it cannot be an inner class
52-
case class ScalarNested(foo: String, bar: String)
52+
case class ScalarNested(
53+
foo: String,
54+
bar: Option[String],
55+
nestedNested: Option[ScalarNestedNested]
56+
)
57+
case class ScalarNestedNested(foo: String, bar: Option[String])
5358

5459
class SdkScalaTypeTest {
5560

@@ -178,7 +183,15 @@ class SdkScalaTypeTest {
178183
Struct.of(
179184
Map(
180185
"foo" -> Struct.Value.ofStringValue("foo"),
181-
"bar" -> Struct.Value.ofStringValue("bar")
186+
"bar" -> Struct.Value.ofNullValue(),
187+
"nestedNested" -> Struct.Value.ofStructValue(
188+
Struct.of(
189+
Map(
190+
"foo" -> Struct.Value.ofStringValue("foo"),
191+
"bar" -> Struct.Value.ofStringValue("bar")
192+
).asJava
193+
)
194+
)
182195
).asJava
183196
)
184197
)
@@ -196,7 +209,11 @@ class SdkScalaTypeTest {
196209
blob = SdkBindingDataFactory.of(blob),
197210
generic = SdkBindingDataFactory.of(
198211
SdkLiteralTypes.generics(),
199-
ScalarNested("foo", "bar")
212+
ScalarNested(
213+
"foo",
214+
None,
215+
Some(ScalarNestedNested("foo", Some("bar")))
216+
)
200217
)
201218
)
202219

@@ -218,7 +235,11 @@ class SdkScalaTypeTest {
218235
blob = SdkBindingDataFactory.of(blob),
219236
generic = SdkBindingDataFactory.of(
220237
SdkLiteralTypes.generics(),
221-
ScalarNested("foo", "bar")
238+
ScalarNested(
239+
"foo",
240+
Some("bar"),
241+
Some(ScalarNestedNested("foo", Some("bar")))
242+
)
222243
)
223244
)
224245

@@ -245,7 +266,15 @@ class SdkScalaTypeTest {
245266
Struct.of(
246267
Map(
247268
"foo" -> Struct.Value.ofStringValue("foo"),
248-
"bar" -> Struct.Value.ofStringValue("bar")
269+
"bar" -> Struct.Value.ofStringValue("bar"),
270+
"nestedNested" -> Struct.Value.ofStructValue(
271+
Struct.of(
272+
Map(
273+
"foo" -> Struct.Value.ofStringValue("foo"),
274+
"bar" -> Struct.Value.ofStringValue("bar")
275+
).asJava
276+
)
277+
)
249278
).asJava
250279
)
251280
)
@@ -285,7 +314,11 @@ class SdkScalaTypeTest {
285314
blob = SdkBindingDataFactory.of(blob),
286315
generic = SdkBindingDataFactory.of(
287316
SdkLiteralTypes.generics(),
288-
ScalarNested("foo", "bar")
317+
ScalarNested(
318+
"foo",
319+
Some("bar"),
320+
Some(ScalarNestedNested("foo", Some("bar")))
321+
)
289322
)
290323
)
291324

@@ -301,7 +334,11 @@ class SdkScalaTypeTest {
301334
"blob" -> SdkBindingDataFactory.of(blob),
302335
"generic" -> SdkBindingDataFactory.of(
303336
SdkLiteralTypes.generics[ScalarNested](),
304-
ScalarNested("foo", "bar")
337+
ScalarNested(
338+
"foo",
339+
Some("bar"),
340+
Some(ScalarNestedNested("foo", Some("bar")))
341+
)
305342
)
306343
).asJava
307344

flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -297,41 +297,39 @@ object SdkLiteralTypes {
297297
): S = {
298298
val mirror = runtimeMirror(classTag[S].runtimeClass.getClassLoader)
299299

300-
def valueToParamValue(value: Any, param: Symbol): Any = {
301-
def valueToParamValue0(value: Any, param: Symbol): Any = {
302-
if (param.typeSignature =:= typeOf[Byte]) {
303-
value.asInstanceOf[Double].toByte
304-
} else if (param.typeSignature =:= typeOf[Short]) {
305-
value.asInstanceOf[Double].toShort
306-
} else if (param.typeSignature =:= typeOf[Int]) {
307-
value.asInstanceOf[Double].toInt
308-
} else if (param.typeSignature =:= typeOf[Long]) {
309-
value.asInstanceOf[Double].toLong
310-
} else if (param.typeSignature =:= typeOf[Float]) {
311-
value.asInstanceOf[Double].toFloat
312-
} else if (param.typeSignature <:< typeOf[Product]) {
313-
val typeTag = createTypeTag(param.typeSignature)
314-
val classTag = ClassTag(
315-
typeTag.mirror.runtimeClass(param.typeSignature)
316-
)
317-
mapToProduct(value.asInstanceOf[Map[String, Any]])(
318-
typeTag,
319-
classTag
320-
)
300+
def valueToParamValue(value: Any, tpe: Type): Any = {
301+
if (tpe =:= typeOf[Byte]) {
302+
value.asInstanceOf[Double].toByte
303+
} else if (tpe =:= typeOf[Short]) {
304+
value.asInstanceOf[Double].toShort
305+
} else if (tpe =:= typeOf[Int]) {
306+
value.asInstanceOf[Double].toInt
307+
} else if (tpe =:= typeOf[Long]) {
308+
value.asInstanceOf[Double].toLong
309+
} else if (tpe =:= typeOf[Float]) {
310+
value.asInstanceOf[Double].toFloat
311+
} else if (tpe <:< typeOf[Option[Any]]) { // this has to be before Product check because Option is a Product
312+
if (value == None) { // None is used to represent Struct.Value.Kind.NULL_VALUE when converting struct to map
313+
None
321314
} else {
322-
value
323-
}
324-
}
325-
326-
if (param.typeSignature <:< typeOf[Option[Any]]) {
327-
Some(
328-
valueToParamValue0(
329-
value,
330-
param.typeSignature.dealias.typeArgs.head.typeSymbol
315+
Some(
316+
valueToParamValue(
317+
value,
318+
tpe.dealias.typeArgs.head
319+
)
331320
)
321+
}
322+
} else if (tpe <:< typeOf[Product]) {
323+
val typeTag = createTypeTag(tpe)
324+
val classTag = ClassTag(
325+
typeTag.mirror.runtimeClass(tpe)
326+
)
327+
mapToProduct(value.asInstanceOf[Map[String, Any]])(
328+
typeTag,
329+
classTag
332330
)
333331
} else {
334-
valueToParamValue0(value, param)
332+
value
335333
}
336334
}
337335

@@ -371,7 +369,7 @@ object SdkLiteralTypes {
371369
s"Map is missing required parameter named $paramName"
372370
)
373371
)
374-
valueToParamValue(value, param)
372+
valueToParamValue(value, param.typeSignature.dealias)
375373
})
376374

377375
constructorMirror(constructorArgs: _*).asInstanceOf[S]

0 commit comments

Comments
 (0)