Skip to content

Commit a8b19e9

Browse files
Add new message type for vars serialization
1 parent 7e0daf5 commit a8b19e9

File tree

4 files changed

+90
-1
lines changed

4 files changed

+90
-1
lines changed

src/main/kotlin/org/jetbrains/kotlinx/jupyter/message_types.kt

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import kotlinx.serialization.json.decodeFromJsonElement
2323
import kotlinx.serialization.json.encodeToJsonElement
2424
import kotlinx.serialization.json.jsonObject
2525
import kotlinx.serialization.serializer
26+
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedVariablesState
2627
import org.jetbrains.kotlinx.jupyter.config.LanguageInfo
2728
import org.jetbrains.kotlinx.jupyter.exceptions.ReplException
2829
import kotlin.reflect.KClass
@@ -87,7 +88,10 @@ enum class MessageType(val contentClass: KClass<out MessageContent>) {
8788
COMM_CLOSE(CommClose::class),
8889

8990
LIST_ERRORS_REQUEST(ListErrorsRequest::class),
90-
LIST_ERRORS_REPLY(ListErrorsReply::class);
91+
LIST_ERRORS_REPLY(ListErrorsReply::class),
92+
93+
SERIALIZATION_REQUEST(SerializationRequest::class),
94+
SERIALIZATION_REPLY(SerializationReply::class);
9195

9296
// TODO: add custom commands
9397
// this custom message should be supported on client-side. either JS or Idea Plugin
@@ -555,6 +559,18 @@ class ListErrorsReply(
555559
val errors: List<ScriptDiagnostic>
556560
) : MessageContent()
557561

562+
@Serializable
563+
class SerializationRequest(
564+
val cellId: Int,
565+
val descriptorsState: Map<String, SerializedVariablesState>
566+
) : MessageContent()
567+
568+
@Serializable
569+
class SerializationReply(
570+
val cellId: Int,
571+
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap()
572+
) : MessageContent()
573+
558574
@Serializable(MessageDataSerializer::class)
559575
data class MessageData(
560576
val header: MessageHeader? = null,

src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,13 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
313313
}
314314
}
315315
}
316+
is SerializationRequest -> {
317+
GlobalScope.launch(Dispatchers.Default) {
318+
repl.serializeVariables(content.cellId, content.descriptorsState) { result ->
319+
sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = result))
320+
}
321+
}
322+
}
316323
is IsCompleteRequest -> {
317324
// We are in console mode, so switch off all the loggers
318325
if (mainLoggerLevel() != Level.OFF) disableLogging()

src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import org.jetbrains.kotlinx.jupyter.compiler.ScriptImportsCollector
2727
import org.jetbrains.kotlinx.jupyter.compiler.util.Classpath
2828
import org.jetbrains.kotlinx.jupyter.compiler.util.EvaluatedSnippetMetadata
2929
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedCompiledScriptsData
30+
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedVariablesState
3031
import org.jetbrains.kotlinx.jupyter.config.catchAll
3132
import org.jetbrains.kotlinx.jupyter.config.getCompilationConfiguration
3233
import org.jetbrains.kotlinx.jupyter.dependencies.JupyterScriptDependenciesResolverImpl
@@ -119,6 +120,8 @@ interface ReplForJupyter {
119120

120121
suspend fun listErrors(code: Code, callback: (ListErrorsResult) -> Unit)
121122

123+
suspend fun serializeVariables(cellId: Int, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit)
124+
122125
val homeDir: File?
123126

124127
val currentClasspath: Collection<String>
@@ -522,6 +525,20 @@ class ReplForJupyterImpl(
522525
return ListErrorsResult(args.code, errorsList)
523526
}
524527

528+
private val serializationQueue = LockQueue<SerializationReply, SerializationArgs>()
529+
override suspend fun serializeVariables(cellId: Int, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit) {
530+
doWithLock(SerializationArgs(cellId, descriptorsState, callback), serializationQueue, SerializationReply(cellId, descriptorsState), ::doSerializeVariables)
531+
}
532+
533+
private fun doSerializeVariables(args: SerializationArgs): SerializationReply {
534+
val resultMap = mutableMapOf<String, SerializedVariablesState>()
535+
args.descriptorsState.forEach { (name, state) ->
536+
resultMap[name] = variablesSerializer.doIncrementalSerialization(args.cellId - 1, name, state)
537+
}
538+
return SerializationReply(args.cellId, resultMap)
539+
}
540+
541+
525542
private fun <T, Args : LockQueueArgs<T>> doWithLock(
526543
args: Args,
527544
queue: LockQueue<T, Args>,
@@ -554,6 +571,12 @@ class ReplForJupyterImpl(
554571
private data class ListErrorsArgs(val code: String, override val callback: (ListErrorsResult) -> Unit) :
555572
LockQueueArgs<ListErrorsResult>
556573

574+
private data class SerializationArgs(
575+
val cellId: Int,
576+
val descriptorsState: Map<String, SerializedVariablesState>,
577+
override val callback: (SerializationReply) -> Unit
578+
) : LockQueueArgs<SerializationReply>
579+
557580
@JvmInline
558581
private value class LockQueue<T, Args : LockQueueArgs<T>>(
559582
private val args: AtomicReference<Args?> = AtomicReference()

src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/repl/ReplTests.kt

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,4 +928,47 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
928928
assertEquals("${values++}", state.value)
929929
}
930930
}
931+
932+
@Test
933+
fun testSerializationMessage() {
934+
val res = eval(
935+
"""
936+
val x = listOf(listOf(1), listOf(2), listOf(3), listOf(4))
937+
""".trimIndent(),
938+
jupyterId = 1
939+
)
940+
val varsData = res.metadata.evaluatedVariablesState
941+
assertEquals(1, varsData.size)
942+
val listData = varsData["x"]!!
943+
assertTrue(listData.isContainer)
944+
val actualContainer = listData.fieldDescriptor.entries.first().value!!
945+
val propertyName = listData.fieldDescriptor.entries.first().key
946+
947+
runBlocking {
948+
repl.serializeVariables(1, mapOf(propertyName to actualContainer)) { result ->
949+
val data = result.descriptorsState
950+
assertTrue(data.isNotEmpty())
951+
952+
val innerList = data.entries.last().value!!
953+
assertTrue(innerList.isContainer)
954+
var receivedDescriptor = innerList.fieldDescriptor
955+
assertEquals(2, receivedDescriptor.size)
956+
receivedDescriptor = receivedDescriptor.entries.last().value!!.fieldDescriptor
957+
958+
assertEquals(5, receivedDescriptor.size)
959+
var values = 1
960+
receivedDescriptor.forEach { (name, state) ->
961+
if (name == "size") {
962+
assertFalse(state!!.isContainer)
963+
assertTrue(state!!.fieldDescriptor.isEmpty())
964+
return@forEach
965+
}
966+
val fieldDescriptor = state!!.fieldDescriptor
967+
assertEquals(0, fieldDescriptor.size)
968+
assertTrue(state.isContainer)
969+
assertEquals("${values++}", state.value)
970+
}
971+
}
972+
}
973+
}
931974
}

0 commit comments

Comments
 (0)