Skip to content

Commit ef903f1

Browse files
Add new message type for vars serialization
1 parent 3a97511 commit ef903f1

File tree

4 files changed

+90
-1
lines changed

4 files changed

+90
-1
lines changed

src/main/kotlin/org/jetbrains/kotlinx/jupyter/message_types.kt

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import kotlinx.serialization.json.decodeFromJsonElement
2323
import kotlinx.serialization.json.encodeToJsonElement
2424
import kotlinx.serialization.json.jsonObject
2525
import kotlinx.serialization.serializer
26+
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedVariablesState
2627
import org.jetbrains.kotlinx.jupyter.config.LanguageInfo
2728
import org.jetbrains.kotlinx.jupyter.exceptions.ReplException
2829
import kotlin.reflect.KClass
@@ -87,7 +88,10 @@ enum class MessageType(val contentClass: KClass<out MessageContent>) {
8788
COMM_CLOSE(CommClose::class),
8889

8990
LIST_ERRORS_REQUEST(ListErrorsRequest::class),
90-
LIST_ERRORS_REPLY(ListErrorsReply::class);
91+
LIST_ERRORS_REPLY(ListErrorsReply::class),
92+
93+
SERIALIZATION_REQUEST(SerializationRequest::class),
94+
SERIALIZATION_REPLY(SerializationReply::class);
9195

9296
// TODO: add custom commands
9397
// this custom message should be supported on client-side. either JS or Idea Plugin
@@ -555,6 +559,18 @@ class ListErrorsReply(
555559
val errors: List<ScriptDiagnostic>
556560
) : MessageContent()
557561

562+
@Serializable
563+
class SerializationRequest(
564+
val cellId: Int,
565+
val descriptorsState: Map<String, SerializedVariablesState>
566+
) : MessageContent()
567+
568+
@Serializable
569+
class SerializationReply(
570+
val cellId: Int,
571+
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap()
572+
) : MessageContent()
573+
558574
@Serializable(MessageDataSerializer::class)
559575
data class MessageData(
560576
val header: MessageHeader? = null,

src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,13 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
321321
}
322322
}
323323
}
324+
is SerializationRequest -> {
325+
GlobalScope.launch(Dispatchers.Default) {
326+
repl.serializeVariables(content.cellId, content.descriptorsState) { result ->
327+
sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = result))
328+
}
329+
}
330+
}
324331
is IsCompleteRequest -> {
325332
// We are in console mode, so switch off all the loggers
326333
if (mainLoggerLevel() != Level.OFF) disableLogging()

src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import org.jetbrains.kotlinx.jupyter.compiler.ScriptImportsCollector
2929
import org.jetbrains.kotlinx.jupyter.compiler.util.Classpath
3030
import org.jetbrains.kotlinx.jupyter.compiler.util.EvaluatedSnippetMetadata
3131
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedCompiledScriptsData
32+
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedVariablesState
3233
import org.jetbrains.kotlinx.jupyter.config.catchAll
3334
import org.jetbrains.kotlinx.jupyter.config.getCompilationConfiguration
3435
import org.jetbrains.kotlinx.jupyter.dependencies.JupyterScriptDependenciesResolverImpl
@@ -133,6 +134,8 @@ interface ReplForJupyter {
133134

134135
suspend fun listErrors(code: Code, callback: (ListErrorsResult) -> Unit)
135136

137+
suspend fun serializeVariables(cellId: Int, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit)
138+
136139
val homeDir: File?
137140

138141
val currentClasspath: Collection<String>
@@ -549,6 +552,20 @@ class ReplForJupyterImpl(
549552
return ListErrorsResult(args.code, errorsList)
550553
}
551554

555+
private val serializationQueue = LockQueue<SerializationReply, SerializationArgs>()
556+
override suspend fun serializeVariables(cellId: Int, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit) {
557+
doWithLock(SerializationArgs(cellId, descriptorsState, callback), serializationQueue, SerializationReply(cellId, descriptorsState), ::doSerializeVariables)
558+
}
559+
560+
private fun doSerializeVariables(args: SerializationArgs): SerializationReply {
561+
val resultMap = mutableMapOf<String, SerializedVariablesState>()
562+
args.descriptorsState.forEach { (name, state) ->
563+
resultMap[name] = variablesSerializer.doIncrementalSerialization(args.cellId - 1, name, state)
564+
}
565+
return SerializationReply(args.cellId, resultMap)
566+
}
567+
568+
552569
private fun <T, Args : LockQueueArgs<T>> doWithLock(
553570
args: Args,
554571
queue: LockQueue<T, Args>,
@@ -581,6 +598,12 @@ class ReplForJupyterImpl(
581598
private data class ListErrorsArgs(val code: String, override val callback: (ListErrorsResult) -> Unit) :
582599
LockQueueArgs<ListErrorsResult>
583600

601+
private data class SerializationArgs(
602+
val cellId: Int,
603+
val descriptorsState: Map<String, SerializedVariablesState>,
604+
override val callback: (SerializationReply) -> Unit
605+
) : LockQueueArgs<SerializationReply>
606+
584607
@JvmInline
585608
private value class LockQueue<T, Args : LockQueueArgs<T>>(
586609
private val args: AtomicReference<Args?> = AtomicReference()

src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/repl/ReplTests.kt

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,4 +832,47 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
832832
assertEquals("${values++}", state.value)
833833
}
834834
}
835+
836+
@Test
837+
fun testSerializationMessage() {
838+
val res = eval(
839+
"""
840+
val x = listOf(listOf(1), listOf(2), listOf(3), listOf(4))
841+
""".trimIndent(),
842+
jupyterId = 1
843+
)
844+
val varsData = res.metadata.evaluatedVariablesState
845+
assertEquals(1, varsData.size)
846+
val listData = varsData["x"]!!
847+
assertTrue(listData.isContainer)
848+
val actualContainer = listData.fieldDescriptor.entries.first().value!!
849+
val propertyName = listData.fieldDescriptor.entries.first().key
850+
851+
runBlocking {
852+
repl.serializeVariables(1, mapOf(propertyName to actualContainer)) { result ->
853+
val data = result.descriptorsState
854+
assertTrue(data.isNotEmpty())
855+
856+
val innerList = data.entries.last().value!!
857+
assertTrue(innerList.isContainer)
858+
var receivedDescriptor = innerList.fieldDescriptor
859+
assertEquals(2, receivedDescriptor.size)
860+
receivedDescriptor = receivedDescriptor.entries.last().value!!.fieldDescriptor
861+
862+
assertEquals(5, receivedDescriptor.size)
863+
var values = 1
864+
receivedDescriptor.forEach { (name, state) ->
865+
if (name == "size") {
866+
assertFalse(state!!.isContainer)
867+
assertTrue(state!!.fieldDescriptor.isEmpty())
868+
return@forEach
869+
}
870+
val fieldDescriptor = state!!.fieldDescriptor
871+
assertEquals(0, fieldDescriptor.size)
872+
assertTrue(state.isContainer)
873+
assertEquals("${values++}", state.value)
874+
}
875+
}
876+
}
877+
}
835878
}

0 commit comments

Comments
 (0)