Skip to content

Commit c18adaf

Browse files
Add new message type for vars serialization
1 parent a6da2d8 commit c18adaf

File tree

4 files changed

+90
-1
lines changed

4 files changed

+90
-1
lines changed

src/main/kotlin/org/jetbrains/kotlinx/jupyter/message_types.kt

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import kotlinx.serialization.json.decodeFromJsonElement
2323
import kotlinx.serialization.json.encodeToJsonElement
2424
import kotlinx.serialization.json.jsonObject
2525
import kotlinx.serialization.serializer
26+
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedVariablesState
2627
import org.jetbrains.kotlinx.jupyter.exceptions.ReplException
2728
import kotlin.reflect.KClass
2829
import kotlin.reflect.full.createType
@@ -86,7 +87,10 @@ enum class MessageType(val contentClass: KClass<out MessageContent>) {
8687
COMM_CLOSE(CommClose::class),
8788

8889
LIST_ERRORS_REQUEST(ListErrorsRequest::class),
89-
LIST_ERRORS_REPLY(ListErrorsReply::class);
90+
LIST_ERRORS_REPLY(ListErrorsReply::class),
91+
92+
SERIALIZATION_REQUEST(SerializationRequest::class),
93+
SERIALIZATION_REPLY(SerializationReply::class);
9094

9195
// TODO: add custom commands
9296
// this custom message should be supported on client-side. either JS or Idea Plugin
@@ -573,6 +577,18 @@ class ListErrorsReply(
573577
val errors: List<ScriptDiagnostic>
574578
) : MessageContent()
575579

580+
@Serializable
581+
class SerializationRequest(
582+
val cellId: Int,
583+
val descriptorsState: Map<String, SerializedVariablesState>
584+
) : MessageContent()
585+
586+
@Serializable
587+
class SerializationReply(
588+
val cellId: Int,
589+
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap()
590+
) : MessageContent()
591+
576592
@Serializable(MessageDataSerializer::class)
577593
data class MessageData(
578594
val header: MessageHeader? = null,

src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,13 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
319319
}
320320
}
321321
}
322+
is SerializationRequest -> {
323+
GlobalScope.launch(Dispatchers.Default) {
324+
repl.serializeVariables(content.cellId, content.descriptorsState) { result ->
325+
sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = result))
326+
}
327+
}
328+
}
322329
is IsCompleteRequest -> {
323330
// We are in console mode, so switch off all the loggers
324331
if (mainLoggerLevel() != Level.OFF) disableLogging()

src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.jetbrains.kotlinx.jupyter.compiler.ScriptImportsCollector
2626
import org.jetbrains.kotlinx.jupyter.compiler.util.Classpath
2727
import org.jetbrains.kotlinx.jupyter.compiler.util.EvaluatedSnippetMetadata
2828
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedCompiledScriptsData
29+
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedVariablesState
2930
import org.jetbrains.kotlinx.jupyter.config.catchAll
3031
import org.jetbrains.kotlinx.jupyter.config.getCompilationConfiguration
3132
import org.jetbrains.kotlinx.jupyter.dependencies.JupyterScriptDependenciesResolverImpl
@@ -117,6 +118,8 @@ interface ReplForJupyter {
117118

118119
suspend fun listErrors(code: Code, callback: (ListErrorsResult) -> Unit)
119120

121+
suspend fun serializeVariables(cellId: Int, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit)
122+
120123
val homeDir: File?
121124

122125
val currentClasspath: Collection<String>
@@ -515,6 +518,20 @@ class ReplForJupyterImpl(
515518
return ListErrorsResult(args.code, errorsList)
516519
}
517520

521+
private val serializationQueue = LockQueue<SerializationReply, SerializationArgs>()
522+
override suspend fun serializeVariables(cellId: Int, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit) {
523+
doWithLock(SerializationArgs(cellId, descriptorsState, callback), serializationQueue, SerializationReply(cellId, descriptorsState), ::doSerializeVariables)
524+
}
525+
526+
private fun doSerializeVariables(args: SerializationArgs): SerializationReply {
527+
val resultMap = mutableMapOf<String, SerializedVariablesState>()
528+
args.descriptorsState.forEach { (name, state) ->
529+
resultMap[name] = variablesSerializer.doIncrementalSerialization(args.cellId - 1, name, state)
530+
}
531+
return SerializationReply(args.cellId, resultMap)
532+
}
533+
534+
518535
private fun <T, Args : LockQueueArgs<T>> doWithLock(
519536
args: Args,
520537
queue: LockQueue<T, Args>,
@@ -547,6 +564,12 @@ class ReplForJupyterImpl(
547564
private data class ListErrorsArgs(val code: String, override val callback: (ListErrorsResult) -> Unit) :
548565
LockQueueArgs<ListErrorsResult>
549566

567+
private data class SerializationArgs(
568+
val cellId: Int,
569+
val descriptorsState: Map<String, SerializedVariablesState>,
570+
override val callback: (SerializationReply) -> Unit
571+
) : LockQueueArgs<SerializationReply>
572+
550573
@JvmInline
551574
private value class LockQueue<T, Args : LockQueueArgs<T>>(
552575
private val args: AtomicReference<Args?> = AtomicReference()

src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/repl/ReplTests.kt

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,4 +911,47 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
911911
assertEquals("${values++}", state.value)
912912
}
913913
}
914+
915+
@Test
916+
fun testSerializationMessage() {
917+
val res = eval(
918+
"""
919+
val x = listOf(listOf(1), listOf(2), listOf(3), listOf(4))
920+
""".trimIndent(),
921+
jupyterId = 1
922+
)
923+
val varsData = res.metadata.evaluatedVariablesState
924+
assertEquals(1, varsData.size)
925+
val listData = varsData["x"]!!
926+
assertTrue(listData.isContainer)
927+
val actualContainer = listData.fieldDescriptor.entries.first().value!!
928+
val propertyName = listData.fieldDescriptor.entries.first().key
929+
930+
runBlocking {
931+
repl.serializeVariables(1, mapOf(propertyName to actualContainer)) { result ->
932+
val data = result.descriptorsState
933+
assertTrue(data.isNotEmpty())
934+
935+
val innerList = data.entries.last().value!!
936+
assertTrue(innerList.isContainer)
937+
var receivedDescriptor = innerList.fieldDescriptor
938+
assertEquals(2, receivedDescriptor.size)
939+
receivedDescriptor = receivedDescriptor.entries.last().value!!.fieldDescriptor
940+
941+
assertEquals(5, receivedDescriptor.size)
942+
var values = 1
943+
receivedDescriptor.forEach { (name, state) ->
944+
if (name == "size") {
945+
assertFalse(state!!.isContainer)
946+
assertTrue(state!!.fieldDescriptor.isEmpty())
947+
return@forEach
948+
}
949+
val fieldDescriptor = state!!.fieldDescriptor
950+
assertEquals(0, fieldDescriptor.size)
951+
assertTrue(state.isContainer)
952+
assertEquals("${values++}", state.value)
953+
}
954+
}
955+
}
956+
}
914957
}

0 commit comments

Comments
 (0)