Skip to content

Commit a92bb3f

Browse files
nikolay-egorovileasile
authored andcommitted
Add possibility to use serialization request with only top-level descriptor name, not cellID
1 parent 4582d38 commit a92bb3f

File tree

12 files changed

+152
-17
lines changed

12 files changed

+152
-17
lines changed

jupyter-lib/api/src/main/kotlin/org/jetbrains/kotlinx/jupyter/api/VariableState.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,21 @@ import kotlin.reflect.jvm.isAccessible
66
import java.lang.reflect.Field
77

88
interface VariableState {
9-
// val property: KProperty<*>
109
val property: Field
1110
val scriptInstance: Any?
1211
val stringValue: String?
1312
val value: Result<Any?>
1413
}
1514

1615
data class VariableStateImpl(
17-
// override val property: KProperty1<Any, *>,
1816
override val property: Field,
1917
override val scriptInstance: Any,
2018
) : VariableState {
2119
private var cachedValue: Result<Any?> = Result.success(null)
2220
private var isRecursive: Boolean = false
2321

22+
// use of Java 9 required
23+
@SuppressWarnings("DEPRECATION")
2424
fun update(): Boolean {
2525
val wasAccessible = property.isAccessible
2626
property.isAccessible = true

jupyter-lib/shared-compiler/src/main/kotlin/org/jetbrains/kotlinx/jupyter/compiler/util/serializedCompiledScript.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ data class SerializedVariablesState(
3030
val fieldDescriptor: MutableMap<String, SerializedVariablesState?> = mutableMapOf()
3131
}
3232

33+
@Serializable
34+
class SerializationReply(
35+
val cellId: Int = 1,
36+
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap()
37+
)
38+
3339
@Serializable
3440
class EvaluatedSnippetMetadata(
3541
val newClasspath: Classpath = emptyList(),

src/main/kotlin/org/jetbrains/kotlinx/jupyter/apiImpl.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ class NotebookImpl(
128128
private var currentCellVariables = mapOf<Int, Set<String>>()
129129
private val history = arrayListOf<CodeCellImpl>()
130130
private var mainCellCreated = false
131+
private val unchangedVariables: MutableSet<String> = mutableSetOf()
131132

132133
val displays = DisplayContainerImpl()
133134

@@ -147,12 +148,16 @@ class NotebookImpl(
147148
fun updateVariablesState(evaluator: InternalEvaluator) {
148149
variablesState += evaluator.variablesHolder
149150
currentCellVariables = evaluator.cellVariables
151+
unchangedVariables.clear()
152+
unchangedVariables.addAll(evaluator.getUnchangedVariables())
150153
}
151154

152155
fun updateVariablesState(varsStateUpdate: Map<String, VariableState>) {
153156
variablesState += varsStateUpdate
154157
}
155158

159+
fun unchangedVariables(): Set<String> = unchangedVariables
160+
156161
fun variablesReportAsHTML(): String {
157162
return generateHTMLVarsReport(variablesState)
158163
}

src/main/kotlin/org/jetbrains/kotlinx/jupyter/message_types.kt

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,6 @@ enum class MessageType(val contentClass: KClass<out MessageContent>) {
9393
SERIALIZATION_REQUEST(SerializationRequest::class),
9494
SERIALIZATION_REPLY(SerializationReply::class);
9595

96-
// TODO: add custom commands
97-
// this custom message should be supported on client-side. either JS or Idea Plugin
98-
9996
val type: String
10097
get() = name.lowercase()
10198
}
@@ -562,12 +559,13 @@ class ListErrorsReply(
562559
@Serializable
563560
class SerializationRequest(
564561
val cellId: Int,
565-
val descriptorsState: Map<String, SerializedVariablesState>
562+
val descriptorsState: Map<String, SerializedVariablesState>,
563+
val topLevelDescriptorName: String = ""
566564
) : MessageContent()
567565

568566
@Serializable
569567
class SerializationReply(
570-
val cellId: Int,
568+
val cellId: Int = 1,
571569
val descriptorsState: Map<String, SerializedVariablesState> = emptyMap()
572570
) : MessageContent()
573571

src/main/kotlin/org/jetbrains/kotlinx/jupyter/protocol.kt

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,21 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
306306
is CommInfoRequest -> {
307307
sendWrapped(msg, makeReplyMessage(msg, MessageType.COMM_INFO_REPLY, content = CommInfoReply(mapOf())))
308308
}
309+
is CommOpen -> {
310+
if (!content.commId.equals(MessageType.SERIALIZATION_REQUEST.name, ignoreCase = true)) {
311+
send(makeReplyMessage(msg, MessageType.NONE))
312+
return
313+
}
314+
log.debug("Message type in CommOpen: $msg, ${msg.type}")
315+
val data = content.data ?: return sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY))
316+
317+
val messageContent = getVariablesDescriptorsFromJson(data)
318+
GlobalScope.launch(Dispatchers.Default) {
319+
repl.serializeVariables(messageContent.topLevelDescriptorName, messageContent.descriptorsState) { result ->
320+
sendWrapped(msg, makeReplyMessage(msg, MessageType.COMM_OPEN, content = result))
321+
}
322+
}
323+
}
309324
is CompleteRequest -> {
310325
connection.launchJob {
311326
repl.complete(content.code, content.cursorPos) { result ->
@@ -322,8 +337,14 @@ fun JupyterConnection.Socket.shellMessagesHandler(msg: Message, repl: ReplForJup
322337
}
323338
is SerializationRequest -> {
324339
GlobalScope.launch(Dispatchers.Default) {
325-
repl.serializeVariables(content.cellId, content.descriptorsState) { result ->
326-
sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = result))
340+
if (content.topLevelDescriptorName.isNotEmpty()) {
341+
repl.serializeVariables(content.topLevelDescriptorName, content.descriptorsState) { result ->
342+
sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = result))
343+
}
344+
} else {
345+
repl.serializeVariables(content.cellId, content.descriptorsState) { result ->
346+
sendWrapped(msg, makeReplyMessage(msg, MessageType.SERIALIZATION_REPLY, content = result))
347+
}
327348
}
328349
}
329350
}

src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl.kt

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ interface ReplForJupyter {
131131

132132
suspend fun serializeVariables(cellId: Int, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit)
133133

134+
suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit)
135+
134136
val homeDir: File?
135137

136138
val currentClasspath: Collection<String>
@@ -538,15 +540,26 @@ class ReplForJupyterImpl(
538540

539541
private val serializationQueue = LockQueue<SerializationReply, SerializationArgs>()
540542
override suspend fun serializeVariables(cellId: Int, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit) {
541-
doWithLock(SerializationArgs(cellId, descriptorsState, callback), serializationQueue, SerializationReply(cellId, descriptorsState), ::doSerializeVariables)
543+
doWithLock(SerializationArgs(descriptorsState, cellId = cellId, callback = callback), serializationQueue, SerializationReply(cellId, descriptorsState), ::doSerializeVariables)
544+
}
545+
546+
override suspend fun serializeVariables(topLevelVarName: String, descriptorsState: Map<String, SerializedVariablesState>, callback: (SerializationReply) -> Unit) {
547+
doWithLock(SerializationArgs(descriptorsState, topLevelVarName = topLevelVarName, callback = callback), serializationQueue, SerializationReply(), ::doSerializeVariables)
542548
}
543549

544550
private fun doSerializeVariables(args: SerializationArgs): SerializationReply {
545551
val resultMap = mutableMapOf<String, SerializedVariablesState>()
552+
val cellId = if (args.cellId != -1) args.cellId else {
553+
val watcherInfo = internalEvaluator.findVariableCell(args.topLevelVarName) + 1
554+
val finalAns = if (watcherInfo == - 1) 1 else watcherInfo
555+
finalAns
556+
}
546557
args.descriptorsState.forEach { (name, state) ->
547-
resultMap[name] = variablesSerializer.doIncrementalSerialization(args.cellId - 1, name, state)
558+
resultMap[name] = variablesSerializer.doIncrementalSerialization(cellId - 1, name, state)
548559
}
549-
return SerializationReply(args.cellId, resultMap)
560+
log.debug("Serialization cellID: $cellId")
561+
log.debug("Serialization answer: ${resultMap.entries.first().value.fieldDescriptor}")
562+
return SerializationReply(cellId, resultMap)
550563
}
551564

552565

@@ -583,11 +596,13 @@ class ReplForJupyterImpl(
583596
LockQueueArgs<ListErrorsResult>
584597

585598
private data class SerializationArgs(
586-
val cellId: Int,
587599
val descriptorsState: Map<String, SerializedVariablesState>,
600+
var cellId: Int = -1,
601+
val topLevelVarName: String = "",
588602
override val callback: (SerializationReply) -> Unit
589603
) : LockQueueArgs<SerializationReply>
590604

605+
591606
@JvmInline
592607
private value class LockQueue<T, Args : LockQueueArgs<T>>(
593608
private val args: AtomicReference<Args?> = AtomicReference()

src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl/InternalEvaluator.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,14 @@ interface InternalEvaluator {
3030
* returns empty data or null
3131
*/
3232
fun popAddedCompiledScripts(): SerializedCompiledScriptsData = SerializedCompiledScriptsData.EMPTY
33+
34+
/**
35+
* Get a cellId where a particular variable is declared
36+
*/
37+
fun findVariableCell(variableName: String): Int
38+
39+
/**
40+
* Returns a set of unaffected variables after execution
41+
*/
42+
fun getUnchangedVariables(): Set<String>
3343
}

src/main/kotlin/org/jetbrains/kotlinx/jupyter/repl/impl/InternalEvaluatorImpl.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ internal class InternalEvaluatorImpl(
4848
return SerializedCompiledScriptsData(scripts)
4949
}
5050

51+
override fun findVariableCell(variableName: String): Int {
52+
return variablesWatcher.findDeclarationAddress(variableName) ?: -1
53+
}
54+
5155
override var writeCompiledClasses: Boolean
5256
get() = classWriter != null
5357
set(value) {

src/main/kotlin/org/jetbrains/kotlinx/jupyter/serializationUtils.kt

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
package org.jetbrains.kotlinx.jupyter
22

3+
import kotlinx.serialization.Serializable
4+
import kotlinx.serialization.json.Json
5+
import kotlinx.serialization.json.JsonObject
6+
import kotlinx.serialization.json.decodeFromJsonElement
37
import org.jetbrains.kotlinx.jupyter.api.VariableState
48
import org.jetbrains.kotlinx.jupyter.compiler.util.SerializedVariablesState
59
import java.lang.reflect.Field
10+
import kotlin.contracts.ExperimentalContracts
11+
import kotlin.contracts.contract
612
import kotlin.reflect.KClass
713
import kotlin.reflect.KProperty
814
import kotlin.reflect.KProperty1
@@ -23,6 +29,16 @@ enum class PropertiesType {
2329
MIXED
2430
}
2531

32+
@Serializable
33+
data class SerializedCommMessageContent(
34+
val topLevelDescriptorName: String,
35+
val descriptorsState: Map<String, SerializedVariablesState>
36+
)
37+
38+
fun getVariablesDescriptorsFromJson(json: JsonObject): SerializedCommMessageContent {
39+
return Json.decodeFromJsonElement<SerializedCommMessageContent>(json)
40+
}
41+
2642
class ProcessedSerializedVarsState(
2743
val serializedVariablesState: SerializedVariablesState,
2844
val propertiesData: PropertiesData? = null,
@@ -94,9 +110,11 @@ class VariablesSerializer(private val serializationDepth: Int = 2, private val s
94110
} == true
95111
}
96112

97-
val kProperties = if (value != null) value::class.declaredMemberProperties else {
98-
null
99-
}
113+
val kProperties = try {
114+
if (value != null) value::class.declaredMemberProperties else {
115+
null
116+
}
117+
} catch (ex: Exception) {null}
100118
val serializedVersion = SerializedVariablesState(simpleTypeName, getProperString(value), true)
101119
val descriptors = serializedVersion.fieldDescriptor
102120
if (isDescriptorsNeeded) {
@@ -198,6 +216,11 @@ class VariablesSerializer(private val serializationDepth: Int = 2, private val s
198216

199217
private val isSerializationActive: Boolean = System.getProperty(serializationSystemProperty)?.toBooleanStrictOrNull() ?: true
200218

219+
/**
220+
* Cache for not recomputing unchanged variables
221+
*/
222+
val serializedVariablesCache: MutableMap<String, SerializedVariablesState> = mutableMapOf()
223+
201224
fun serializeVariables(cellId: Int, variablesState: Map<String, VariableState>): Map<String, SerializedVariablesState> {
202225
if (!isSerializationActive) return emptyMap()
203226

@@ -375,6 +398,7 @@ class VariablesSerializer(private val serializationDepth: Int = 2, private val s
375398
* Really wanted to use contracts here, but all usages should be provided with this annotation and,
376399
* perhaps, it may be a big overhead
377400
*/
401+
@OptIn(ExperimentalContracts::class)
378402
private fun iterateThrough(
379403
elem: Any,
380404
seenObjectsPerCell: MutableMap<RuntimeObjectWrapper, SerializedVariablesState>?,
@@ -383,6 +407,10 @@ class VariablesSerializer(private val serializationDepth: Int = 2, private val s
383407
instancesPerState: MutableMap<SerializedVariablesState, Any?>,
384408
callInstance: Any
385409
) {
410+
contract {
411+
returns() implies (elem is Field || elem is KProperty1<*, *>)
412+
}
413+
386414
val name = if (elem is Field) elem.name else (elem as KProperty1<Any, *>).name
387415
val value = if (elem is Field) tryGetValueFromProperty(elem, callInstance).toObjectWrapper()
388416
else {
@@ -491,6 +519,8 @@ class VariablesSerializer(private val serializationDepth: Int = 2, private val s
491519
return value
492520
}
493521

522+
// use of Java 9 required
523+
@SuppressWarnings("DEPRECATION")
494524
private fun tryGetValueFromProperty(property: Field, callInstance: Any): Any? {
495525
// some fields may be optimized out like array size. Thus, calling it.isAccessible would return error
496526
val canAccess = try {

src/main/kotlin/org/jetbrains/kotlinx/jupyter/util.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ fun ResultsRenderersProcessor.registerDefaultRenderers() {
7373
* Stores info about where a variable Y was declared and info about what are they at the address X.
7474
* K: key, stands for a way of addressing variables, e.g. address.
7575
* V: value, from Variable, choose any suitable type for your variable reference.
76-
* Default: T=Int, V=String
76+
* Default: K=Int, V=String
7777
*/
7878
class VariablesUsagesPerCellWatcher<K : Any, V : Any> {
7979
val cellVariables = mutableMapOf<K, MutableSet<V>>()
@@ -106,5 +106,7 @@ class VariablesUsagesPerCellWatcher<K : Any, V : Any> {
106106
}
107107
}
108108

109+
fun findDeclarationAddress(variableRef: V) = variablesDeclarationInfo[variableRef]
110+
109111
fun ensureStorageCreation(address: K) = cellVariables.putIfAbsent(address, mutableSetOf())
110112
}

0 commit comments

Comments
 (0)