Skip to content

Commit 79bebea

Browse files
Add new message type for vars serialization
1 parent 0131e41 commit 79bebea

File tree

4 files changed

+90
-1
lines changed

4 files changed

+90
-1
lines changed

src/main/kotlin/org/jetbrains/kotlinx/jupyter/message_types.kt

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import kotlinx.serialization.json.decodeFromJsonElement
2323
import kotlinx.serialization.json.encodeToJsonElement
2424
import kotlinx.serialization.json.jsonObject
2525
import kotlinx.serialization.serializer
26+
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedVariablesState
2627
import org.jetbrains.kotlinx.jupyter.config.LanguageInfo
2728
import org.jetbrains.kotlinx.jupyter.exceptions.ReplException
2829
import kotlin.reflect.KClass
@@ -87,7 +88,10 @@ enum class MessageType(val contentClass: KClass<out MessageContent>) {
8788
COMM_CLOSE(CommClose::class),
8889

8990
LIST_ERRORS_REQUEST(ListErrorsRequest::class),
90-
LIST_ERRORS_REPLY(ListErrorsReply::class);
91+
LIST_ERRORS_REPLY(ListErrorsReply::class),
92+
93+
SERIALIZATION_REQUEST(SerializationRequest::class),
94+
SERIALIZATION_REPLY(SerializationReply::class);
9195

9296
// TODO: add custom commands
9397
// this custom message should be supported on client-side. either JS or Idea Plugin
@@ -555,6 +559,18 @@ class ListErrorsReply(
555559
val errors: List<ScriptDiagnostic>
556560
) : MessageContent()
557561

562+
@Serializable
563+
class SerializationRequest(
564+
val cellId: Int,
565+
val descriptorsState: Map<String, SerializedVariablesState>
566+
) : MessageContent()
567+
568+
@Serializable
569+
class SerializationReply(
570+
val cellId: Int,
571+
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap()
572+
) : MessageContent()
573+
558574
@Serializable(MessageDataSerializer::class)
559575
data class MessageData(
560576
val header: MessageHeader? = null,

src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,13 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
321321
}
322322
}
323323
}
324+
is SerializationRequest -> {
325+
GlobalScope.launch(Dispatchers.Default) {
326+
repl.serializeVariables(content.cellId, content.descriptorsState) { result ->
327+
sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = result))
328+
}
329+
}
330+
}
324331
is IsCompleteRequest -> {
325332
// We are in console mode, so switch off all the loggers
326333
if (mainLoggerLevel() != Level.OFF) disableLogging()

src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import org.jetbrains.kotlinx.jupyter.compiler.ScriptImportsCollector
2929
import org.jetbrains.kotlinx.jupyter.compiler.util.Classpath
3030
import org.jetbrains.kotlinx.jupyter.compiler.util.EvaluatedSnippetMetadata
3131
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedCompiledScriptsData
32+
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedVariablesState
3233
import org.jetbrains.kotlinx.jupyter.config.catchAll
3334
import org.jetbrains.kotlinx.jupyter.config.getCompilationConfiguration
3435
import org.jetbrains.kotlinx.jupyter.dependencies.JupyterScriptDependenciesResolverImpl
@@ -136,6 +137,8 @@ interface ReplForJupyter {
136137

137138
suspend fun listErrors(code: Code, callback: (ListErrorsResult) -> Unit)
138139

140+
suspend fun serializeVariables(cellId: Int, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit)
141+
139142
val homeDir: File?
140143

141144
val currentClasspath: Collection<String>
@@ -557,6 +560,20 @@ class ReplForJupyterImpl(
557560
return ListErrorsResult(args.code, errorsList)
558561
}
559562

563+
private val serializationQueue = LockQueue<SerializationReply, SerializationArgs>()
564+
override suspend fun serializeVariables(cellId: Int, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit) {
565+
doWithLock(SerializationArgs(cellId, descriptorsState, callback), serializationQueue, SerializationReply(cellId, descriptorsState), ::doSerializeVariables)
566+
}
567+
568+
private fun doSerializeVariables(args: SerializationArgs): SerializationReply {
569+
val resultMap = mutableMapOf<String, SerializedVariablesState>()
570+
args.descriptorsState.forEach { (name, state) ->
571+
resultMap[name] = variablesSerializer.doIncrementalSerialization(args.cellId - 1, name, state)
572+
}
573+
return SerializationReply(args.cellId, resultMap)
574+
}
575+
576+
560577
private fun <T, Args : LockQueueArgs<T>> doWithLock(
561578
args: Args,
562579
queue: LockQueue<T, Args>,
@@ -589,6 +606,12 @@ class ReplForJupyterImpl(
589606
private data class ListErrorsArgs(val code: String, override val callback: (ListErrorsResult) -> Unit) :
590607
LockQueueArgs<ListErrorsResult>
591608

609+
private data class SerializationArgs(
610+
val cellId: Int,
611+
val descriptorsState: Map<String, SerializedVariablesState>,
612+
override val callback: (SerializationReply) -> Unit
613+
) : LockQueueArgs<SerializationReply>
614+
592615
@JvmInline
593616
private value class LockQueue<T, Args : LockQueueArgs<T>>(
594617
private val args: AtomicReference<Args?> = AtomicReference()

src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/repl/ReplTests.kt

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,4 +845,47 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
845845
assertEquals("${values++}", state.value)
846846
}
847847
}
848+
849+
@Test
850+
fun testSerializationMessage() {
851+
val res = eval(
852+
"""
853+
val x = listOf(listOf(1), listOf(2), listOf(3), listOf(4))
854+
""".trimIndent(),
855+
jupyterId = 1
856+
)
857+
val varsData = res.metadata.evaluatedVariablesState
858+
assertEquals(1, varsData.size)
859+
val listData = varsData["x"]!!
860+
assertTrue(listData.isContainer)
861+
val actualContainer = listData.fieldDescriptor.entries.first().value!!
862+
val propertyName = listData.fieldDescriptor.entries.first().key
863+
864+
runBlocking {
865+
repl.serializeVariables(1, mapOf(propertyName to actualContainer)) { result ->
866+
val data = result.descriptorsState
867+
assertTrue(data.isNotEmpty())
868+
869+
val innerList = data.entries.last().value!!
870+
assertTrue(innerList.isContainer)
871+
var receivedDescriptor = innerList.fieldDescriptor
872+
assertEquals(2, receivedDescriptor.size)
873+
receivedDescriptor = receivedDescriptor.entries.last().value!!.fieldDescriptor
874+
875+
assertEquals(5, receivedDescriptor.size)
876+
var values = 1
877+
receivedDescriptor.forEach { (name, state) ->
878+
if (name == "size") {
879+
assertFalse(state!!.isContainer)
880+
assertTrue(state!!.fieldDescriptor.isEmpty())
881+
return@forEach
882+
}
883+
val fieldDescriptor = state!!.fieldDescriptor
884+
assertEquals(0, fieldDescriptor.size)
885+
assertTrue(state.isContainer)
886+
assertEquals("${values++}", state.value)
887+
}
888+
}
889+
}
890+
}
848891
}

0 commit comments

Comments
 (0)