Skip to content

Commit 17a85c6

Browse files
nikolay-egorovileasile
authored andcommitted
Add new message type for vars serialization
1 parent e1a1e2f commit 17a85c6

File tree

4 files changed

+90
-1
lines changed

4 files changed

+90
-1
lines changed

src/main/kotlin/org/jetbrains/kotlinx/jupyter/message_types.kt

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import kotlinx.serialization.json.decodeFromJsonElement
2323
import kotlinx.serialization.json.encodeToJsonElement
2424
import kotlinx.serialization.json.jsonObject
2525
import kotlinx.serialization.serializer
26+
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedVariablesState
2627
import org.jetbrains.kotlinx.jupyter.config.LanguageInfo
2728
import org.jetbrains.kotlinx.jupyter.exceptions.ReplException
2829
import kotlin.reflect.KClass
@@ -87,7 +88,10 @@ enum class MessageType(val contentClass: KClass<out MessageContent>) {
8788
COMM_CLOSE(CommClose::class),
8889

8990
LIST_ERRORS_REQUEST(ListErrorsRequest::class),
90-
LIST_ERRORS_REPLY(ListErrorsReply::class);
91+
LIST_ERRORS_REPLY(ListErrorsReply::class),
92+
93+
SERIALIZATION_REQUEST(SerializationRequest::class),
94+
SERIALIZATION_REPLY(SerializationReply::class);
9195

9296
// TODO: add custom commands
9397
// this custom message should be supported on client-side. either JS or Idea Plugin
@@ -555,6 +559,18 @@ class ListErrorsReply(
555559
val errors: List<ScriptDiagnostic>
556560
) : MessageContent()
557561

562+
@Serializable
563+
class SerializationRequest(
564+
val cellId: Int,
565+
val descriptorsState: Map<String, SerializedVariablesState>
566+
) : MessageContent()
567+
568+
@Serializable
569+
class SerializationReply(
570+
val cellId: Int,
571+
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap()
572+
) : MessageContent()
573+
558574
@Serializable(MessageDataSerializer::class)
559575
data class MessageData(
560576
val header: MessageHeader? = null,

src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,13 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
320320
}
321321
}
322322
}
323+
is SerializationRequest -> {
324+
GlobalScope.launch(Dispatchers.Default) {
325+
repl.serializeVariables(content.cellId, content.descriptorsState) { result ->
326+
sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = result))
327+
}
328+
}
329+
}
323330
is IsCompleteRequest -> {
324331
// We are in console mode, so switch off all the loggers
325332
if (mainLoggerLevel() != Level.OFF) disableLogging()

src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import org.jetbrains.kotlinx.jupyter.compiler.ScriptImportsCollector
2727
import org.jetbrains.kotlinx.jupyter.compiler.util.Classpath
2828
import org.jetbrains.kotlinx.jupyter.compiler.util.EvaluatedSnippetMetadata
2929
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedCompiledScriptsData
30+
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedVariablesState
3031
import org.jetbrains.kotlinx.jupyter.config.catchAll
3132
import org.jetbrains.kotlinx.jupyter.config.getCompilationConfiguration
3233
import org.jetbrains.kotlinx.jupyter.dependencies.JupyterScriptDependenciesResolverImpl
@@ -128,6 +129,8 @@ interface ReplForJupyter {
128129

129130
suspend fun listErrors(code: Code, callback: (ListErrorsResult) -> Unit)
130131

132+
suspend fun serializeVariables(cellId: Int, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit)
133+
131134
val homeDir: File?
132135

133136
val currentClasspath: Collection<String>
@@ -533,6 +536,20 @@ class ReplForJupyterImpl(
533536
return ListErrorsResult(args.code, errorsList)
534537
}
535538

539+
private val serializationQueue = LockQueue<SerializationReply, SerializationArgs>()
540+
override suspend fun serializeVariables(cellId: Int, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit) {
541+
doWithLock(SerializationArgs(cellId, descriptorsState, callback), serializationQueue, SerializationReply(cellId, descriptorsState), ::doSerializeVariables)
542+
}
543+
544+
private fun doSerializeVariables(args: SerializationArgs): SerializationReply {
545+
val resultMap = mutableMapOf<String, SerializedVariablesState>()
546+
args.descriptorsState.forEach { (name, state) ->
547+
resultMap[name] = variablesSerializer.doIncrementalSerialization(args.cellId - 1, name, state)
548+
}
549+
return SerializationReply(args.cellId, resultMap)
550+
}
551+
552+
536553
private fun <T, Args : LockQueueArgs<T>> doWithLock(
537554
args: Args,
538555
queue: LockQueue<T, Args>,
@@ -565,6 +582,12 @@ class ReplForJupyterImpl(
565582
private data class ListErrorsArgs(val code: String, override val callback: (ListErrorsResult) -> Unit) :
566583
LockQueueArgs<ListErrorsResult>
567584

585+
private data class SerializationArgs(
586+
val cellId: Int,
587+
val descriptorsState: Map<String, SerializedVariablesState>,
588+
override val callback: (SerializationReply) -> Unit
589+
) : LockQueueArgs<SerializationReply>
590+
568591
@JvmInline
569592
private value class LockQueue<T, Args : LockQueueArgs<T>>(
570593
private val args: AtomicReference<Args?> = AtomicReference()

src/test/kotlin/org/jetbrains/kotlinx/jupyter/test/repl/ReplTests.kt

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -955,4 +955,47 @@ class ReplVarsSerializationTest : AbstractSingleReplTest() {
955955
assertEquals("${values++}", state.value)
956956
}
957957
}
958+
959+
@Test
960+
fun testSerializationMessage() {
961+
val res = eval(
962+
"""
963+
val x = listOf(listOf(1), listOf(2), listOf(3), listOf(4))
964+
""".trimIndent(),
965+
jupyterId = 1
966+
)
967+
val varsData = res.metadata.evaluatedVariablesState
968+
assertEquals(1, varsData.size)
969+
val listData = varsData["x"]!!
970+
assertTrue(listData.isContainer)
971+
val actualContainer = listData.fieldDescriptor.entries.first().value!!
972+
val propertyName = listData.fieldDescriptor.entries.first().key
973+
974+
runBlocking {
975+
repl.serializeVariables(1, mapOf(propertyName to actualContainer)) { result ->
976+
val data = result.descriptorsState
977+
assertTrue(data.isNotEmpty())
978+
979+
val innerList = data.entries.last().value!!
980+
assertTrue(innerList.isContainer)
981+
var receivedDescriptor = innerList.fieldDescriptor
982+
assertEquals(2, receivedDescriptor.size)
983+
receivedDescriptor = receivedDescriptor.entries.last().value!!.fieldDescriptor
984+
985+
assertEquals(5, receivedDescriptor.size)
986+
var values = 1
987+
receivedDescriptor.forEach { (name, state) ->
988+
if (name == "size") {
989+
assertFalse(state!!.isContainer)
990+
assertTrue(state!!.fieldDescriptor.isEmpty())
991+
return@forEach
992+
}
993+
val fieldDescriptor = state!!.fieldDescriptor
994+
assertEquals(0, fieldDescriptor.size)
995+
assertTrue(state.isContainer)
996+
assertEquals("${values++}", state.value)
997+
}
998+
}
999+
}
1000+
}
9581001
}

0 commit comments

Comments
 (0)