diff --git a/src/main/java/at/ac/uibk/dps/cirrina/execution/object/expression/JexlExpression.java b/src/main/java/at/ac/uibk/dps/cirrina/execution/object/expression/JexlExpression.java index d7bf80ab..c045b59f 100644 --- a/src/main/java/at/ac/uibk/dps/cirrina/execution/object/expression/JexlExpression.java +++ b/src/main/java/at/ac/uibk/dps/cirrina/execution/object/expression/JexlExpression.java @@ -1,9 +1,20 @@ package at.ac.uibk.dps.cirrina.execution.object.expression; import at.ac.uibk.dps.cirrina.execution.object.context.Extent; +import java.io.IOException; +import java.lang.reflect.Array; +import java.util.Collection; import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.NoSuchElementException; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import org.apache.commons.jexl3.JexlArithmetic; import org.apache.commons.jexl3.JexlBuilder; import org.apache.commons.jexl3.JexlContext; import org.apache.commons.jexl3.JexlEngine; @@ -49,15 +60,18 @@ private static JexlEngine getJexlEngine() { final Map namespaces = new HashMap<>(); namespaces.put("math", Math.class); // Enable math methods, e.g. math:sin(x), math:min(x, y), math:random() - namespaces.put("utility", Utility.class); + namespaces.put("std", Stdlib.class); - var features = new JexlFeatures().sideEffectGlobal(false).sideEffect(true); + var features = new JexlFeatures().sideEffectGlobal(true).sideEffect(true); return new JexlBuilder() + .arithmetic(new CsmlArithmetic(true)) .features(features) .cache(CACHE_SIZE) .namespaces(namespaces) .permissions(JexlPermissions.UNRESTRICTED) + .strict(true) + .silent(false) .create(); } @@ -96,11 +110,149 @@ public Object get(String key) { } @Override - public void set(String key, Object value) {} + public void set(String key, Object value) { + try { + extent.trySet(key, value); + } catch (IOException e) { + throw new NoSuchElementException(String.format("Variable not found: %s", key)); + } + } @Override public boolean has(String key) { return extent.resolve(key).isPresent(); } } + + /** + * CsmlArithmetic extends JexlArithmetic to provide custom operator overloading + * for collections, arrays, and maps in JEXL expressions. + * + *

Supports: + *

+ * + *

All operations are side-effect-free. + */ + private static class CsmlArithmetic extends JexlArithmetic { + + /** + * Constructs a CsmlArithmetic instance with the specified strict mode. + * + * @param strict if true, the arithmetic engine runs in strict mode, + * where it throws exceptions for errors + */ + public CsmlArithmetic(boolean strict) { + super(strict); + } + + /** + * Adds two objects together + * + * @param left the first operand + * @param right the second operand + * @return the result of the addition + */ + @Override + public Object add(Object left, Object right) { + // Left is a list: concatenate left and right, result is a List + if (left instanceof List) { + return Stream.concat(toStream(left), toStream(right)).collect(Collectors.toList()); + } + + // Left is a set: concatenate left and right, result is a LinkedHashSet to preserve uniqueness + if (left instanceof Set) { + return Stream.concat(toStream(left), toStream(right)).collect( + Collectors.toCollection(LinkedHashSet::new) + ); + } + + // Left is an array: concatenate left and right streams, result is an Object[] + if (left != null && left.getClass().isArray()) { + return Stream.concat(toStream(left), toStream(right)).toArray(); + } + + // Left and right are maps: merge entries, right-hand side overwrites left-hand side keys + if (left instanceof Map lm && right instanceof Map rm) { + return Stream.concat(lm.entrySet().stream(), rm.entrySet().stream()).collect( + Collectors.toMap(Entry::getKey, Entry::getValue, (oldV, newV) -> newV, HashMap::new) + ); + } + + // Delegate to default arithmetic + return super.add(left, right); + } + + @Override + public Object subtract(Object left, Object right) { + // Left is List/Set/Array: remove elements present in right + if (isIterableLike(left) && isIterableLike(right)) { + Set rightSet = toStream(right).collect(Collectors.toSet()); + Stream leftStream = toStream(left).filter(e -> !rightSet.contains(e)); + + if (left instanceof List) { + return leftStream.collect(Collectors.toList()); + } + + if (left instanceof Set) { + return leftStream.collect(Collectors.toCollection(LinkedHashSet::new)); + } + + if (left.getClass().isArray()) { + Class componentType = left.getClass().getComponentType(); + return leftStream.toArray(size -> (Object[]) Array.newInstance(componentType, size)); + } + } + + // Left is Map, right is Map or iterable of keys + if (left instanceof Map lm) { + Set keysToRemove = Stream.of(right) + .flatMap(r -> { + if (r instanceof Map rm) { + return rm.keySet().stream(); + } else if (isIterableLike(r)) { + return toStream(r); + } else { + return Stream.of(r); + } + }) + .collect(Collectors.toSet()); + + return lm + .entrySet() + .stream() + .filter(e -> !keysToRemove.contains(e.getKey())) + .collect( + Collectors.toMap( + Entry::getKey, + Entry::getValue, + (a, b) -> b, + () -> new HashMap<>(lm.size()) + ) + ); + } + + // Delegate to default arithmetic + return super.subtract(left, right); + } + + private static boolean isIterableLike(Object o) { + return o instanceof Collection || (o != null && o.getClass().isArray()); + } + + private static Stream toStream(Object o) { + if (o instanceof Collection c) { + return c.stream().map(x -> x); + } + if (o != null && o.getClass().isArray()) { + return IntStream.range(0, Array.getLength(o)).mapToObj(i -> Array.get(o, i)); + } + return Stream.empty(); + } + } } diff --git a/src/main/java/at/ac/uibk/dps/cirrina/execution/object/expression/Stdlib.kt b/src/main/java/at/ac/uibk/dps/cirrina/execution/object/expression/Stdlib.kt new file mode 100644 index 00000000..cf26748f --- /dev/null +++ b/src/main/java/at/ac/uibk/dps/cirrina/execution/object/expression/Stdlib.kt @@ -0,0 +1,17 @@ +package at.ac.uibk.dps.cirrina.execution.`object`.expression + +import java.util.* + +class Stdlib { + companion object { + @JvmStatic + fun genRandPayload(sizes: IntArray): ByteArray { + val rand = Random() + + val randomIndex = rand.nextInt(sizes.size) + val selectedSize = sizes[randomIndex] + + return ByteArray(selectedSize) + } + } +} diff --git a/src/main/java/at/ac/uibk/dps/cirrina/execution/object/expression/Utility.java b/src/main/java/at/ac/uibk/dps/cirrina/execution/object/expression/Utility.java deleted file mode 100644 index d5aa86fc..00000000 --- a/src/main/java/at/ac/uibk/dps/cirrina/execution/object/expression/Utility.java +++ /dev/null @@ -1,15 +0,0 @@ -package at.ac.uibk.dps.cirrina.execution.object.expression; - -import java.util.Random; - -public final class Utility { - - public static byte[] genRandPayload(int[] sizes) { - final var rand = new Random(); - - final var randomIndex = rand.nextInt(sizes.length); - final var selectedSize = sizes[randomIndex]; - - return new byte[selectedSize]; - } -} diff --git a/src/test/java/at/ac/uibk/dps/cirrina/execution/object/expression/ExpressionTest.java b/src/test/java/at/ac/uibk/dps/cirrina/execution/object/expression/ExpressionTest.java index 466b80cd..54773779 100644 --- a/src/test/java/at/ac/uibk/dps/cirrina/execution/object/expression/ExpressionTest.java +++ b/src/test/java/at/ac/uibk/dps/cirrina/execution/object/expression/ExpressionTest.java @@ -1,15 +1,20 @@ package at.ac.uibk.dps.cirrina.execution.object.expression; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertIterableEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import at.ac.uibk.dps.cirrina.execution.object.context.Extent; import at.ac.uibk.dps.cirrina.execution.object.context.InMemoryContext; import java.nio.ByteBuffer; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; +import java.util.Set; import org.junit.jupiter.api.Test; class ExpressionTest { @@ -50,6 +55,244 @@ void testExpression() throws Exception { } } + @Test + void testArrayArithmetic() { + try (var context = new InMemoryContext(true)) { + assertDoesNotThrow(() -> { + var extent = new Extent(context); + + // Array with 1, 2, 3 + context.create("someArray", ExpressionBuilder.from("[1, 2, 3]").build().execute(extent)); + + // Add 4, 5, 6 + ExpressionBuilder.from("someArray = someArray + [4]").build().execute(extent); + ExpressionBuilder.from("someArray = someArray + {5}").build().execute(extent); + ExpressionBuilder.from("someArray = someArray + [6, ...]").build().execute(extent); + + assertArrayEquals( + new Object[] { 1, 2, 3, 4, 5, 6 }, + (Object[]) extent.resolve("someArray").get() + ); + + // Assert presence + assertEquals(true, ExpressionBuilder.from("someArray.contains(1)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someArray.contains(2)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someArray.contains(3)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someArray.contains(4)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someArray.contains(5)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someArray.contains(6)").build().execute(extent)); + + // Remove 4 + ExpressionBuilder.from("someArray = someArray - [4]").build().execute(extent); + + assertArrayEquals( + new Object[] { 1, 2, 3, 5, 6 }, + (Object[]) extent.resolve("someArray").get() + ); + + // Remove 5 + ExpressionBuilder.from("someArray = someArray - {5}").build().execute(extent); + + assertArrayEquals( + new Object[] { 1, 2, 3, 6 }, + (Object[]) extent.resolve("someArray").get() + ); + + // Remove 6 + ExpressionBuilder.from("someArray = someArray - [6, ...]").build().execute(extent); + + assertArrayEquals(new Object[] { 1, 2, 3 }, (Object[]) extent.resolve("someArray").get()); + + // Assert absence + assertEquals( + false, + ExpressionBuilder.from("someArray.contains(4)").build().execute(extent) + ); + assertEquals( + false, + ExpressionBuilder.from("someArray.contains(5)").build().execute(extent) + ); + assertEquals( + false, + ExpressionBuilder.from("someArray.contains(6)").build().execute(extent) + ); + }); + } + } + + @Test + void testListArithmetic() { + try (var context = new InMemoryContext(true)) { + assertDoesNotThrow(() -> { + var extent = new Extent(context); + + // List with 1, 2, 3 + context.create( + "someList", + ExpressionBuilder.from("[1, 2, 3, ...]").build().execute(extent) + ); + + // Add 4, 5, 6 + ExpressionBuilder.from("someList = someList + [4]").build().execute(extent); + ExpressionBuilder.from("someList = someList + {5}").build().execute(extent); + ExpressionBuilder.from("someList = someList + [6, ...]").build().execute(extent); + + assertIterableEquals(List.of(1, 2, 3, 4, 5, 6), (List) extent.resolve("someList").get()); + + // Assert presence + assertEquals(true, ExpressionBuilder.from("someList.contains(1)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someList.contains(2)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someList.contains(3)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someList.contains(4)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someList.contains(5)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someList.contains(6)").build().execute(extent)); + + // Remove 4 + ExpressionBuilder.from("someList = someList - [4]").build().execute(extent); + + assertIterableEquals(List.of(1, 2, 3, 5, 6), (List) extent.resolve("someList").get()); + + // Remove 5 + ExpressionBuilder.from("someList = someList - {5}").build().execute(extent); + + assertIterableEquals(List.of(1, 2, 3, 6), (List) extent.resolve("someList").get()); + + // Remove 6 + ExpressionBuilder.from("someList = someList - [6, ...]").build().execute(extent); + + assertIterableEquals(List.of(1, 2, 3), (List) extent.resolve("someList").get()); + + // Assert absence + assertEquals(false, ExpressionBuilder.from("someList.contains(4)").build().execute(extent)); + assertEquals(false, ExpressionBuilder.from("someList.contains(5)").build().execute(extent)); + assertEquals(false, ExpressionBuilder.from("someList.contains(6)").build().execute(extent)); + }); + } + } + + @Test + void testSetArithmetic() { + try (var context = new InMemoryContext(true)) { + assertDoesNotThrow(() -> { + var extent = new Extent(context); + + // Set with 1, 2, 3 + context.create("someList", ExpressionBuilder.from("{1, 2, 3}").build().execute(extent)); + + // Add 4, 5, 6 + ExpressionBuilder.from("someList = someList + [4]").build().execute(extent); + ExpressionBuilder.from("someList = someList + {5}").build().execute(extent); + ExpressionBuilder.from("someList = someList + [6, ...]").build().execute(extent); + + assertIterableEquals( + new LinkedHashSet<>(List.of(1, 2, 3, 4, 5, 6)), + (Set) extent.resolve("someList").get() + ); + + // Assert presence + assertEquals(true, ExpressionBuilder.from("someList.contains(1)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someList.contains(2)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someList.contains(3)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someList.contains(4)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someList.contains(5)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someList.contains(6)").build().execute(extent)); + + // Remove 4 + ExpressionBuilder.from("someList = someList - [4]").build().execute(extent); + + assertIterableEquals( + new LinkedHashSet<>(List.of(1, 2, 3, 5, 6)), + (Set) extent.resolve("someList").get() + ); + + // Remove 5 + ExpressionBuilder.from("someList = someList - {5}").build().execute(extent); + + assertIterableEquals( + new LinkedHashSet<>(List.of(1, 2, 3, 6)), + (Set) extent.resolve("someList").get() + ); + + // Remove 6 + ExpressionBuilder.from("someList = someList - [6, ...]").build().execute(extent); + + assertIterableEquals( + new LinkedHashSet<>(List.of(1, 2, 3)), + (Set) extent.resolve("someList").get() + ); + + // Assert absence + assertEquals(false, ExpressionBuilder.from("someList.contains(4)").build().execute(extent)); + assertEquals(false, ExpressionBuilder.from("someList.contains(5)").build().execute(extent)); + assertEquals(false, ExpressionBuilder.from("someList.contains(6)").build().execute(extent)); + }); + } + } + + @Test + void testMapArithmetic() { + try (var context = new InMemoryContext(true)) { + assertDoesNotThrow(() -> { + var extent = new Extent(context); + + // Map with 1:2 + context.create("someMap", ExpressionBuilder.from("{1:2}").build().execute(extent)); + + // Add 3:4, 5:6, 7:8, 9:10, 11:12 + ExpressionBuilder.from("someMap = someMap + {3:4}").build().execute(extent); + ExpressionBuilder.from("someMap = someMap + {5:6}").build().execute(extent); + ExpressionBuilder.from("someMap = someMap + {7:8}").build().execute(extent); + ExpressionBuilder.from("someMap = someMap + {9:10}").build().execute(extent); + ExpressionBuilder.from("someMap = someMap + {11:12}").build().execute(extent); + + assertEquals( + Map.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), + extent.resolve("someMap").get() + ); + + // Assert presence + assertEquals(true, ExpressionBuilder.from("someMap.contains(1)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someMap.contains(3)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someMap.contains(5)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someMap.contains(7)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someMap.contains(9)").build().execute(extent)); + assertEquals(true, ExpressionBuilder.from("someMap.contains(11)").build().execute(extent)); + + // Remove 3:4 + ExpressionBuilder.from("someMap = someMap - {3:4}").build().execute(extent); + + assertEquals(Map.of(1, 2, 5, 6, 7, 8, 9, 10, 11, 12), extent.resolve("someMap").get()); + + // Remove 5:6 + ExpressionBuilder.from("someMap = someMap - [5]").build().execute(extent); + + assertEquals(Map.of(1, 2, 7, 8, 9, 10, 11, 12), extent.resolve("someMap").get()); + + // Remove 7:8 + ExpressionBuilder.from("someMap = someMap - [7, ...]").build().execute(extent); + + assertEquals(Map.of(1, 2, 9, 10, 11, 12), extent.resolve("someMap").get()); + + // Remove 9:10 + ExpressionBuilder.from("someMap = someMap - {9}").build().execute(extent); + + assertEquals(Map.of(1, 2, 11, 12), extent.resolve("someMap").get()); + + // Remove 11:12 + ExpressionBuilder.from("someMap = someMap - 11").build().execute(extent); + + assertEquals(Map.of(1, 2), extent.resolve("someMap").get()); + + // Assert absence + assertEquals(false, ExpressionBuilder.from("someMap.contains(3)").build().execute(extent)); + assertEquals(false, ExpressionBuilder.from("someMap.contains(5)").build().execute(extent)); + assertEquals(false, ExpressionBuilder.from("someMap.contains(7)").build().execute(extent)); + assertEquals(false, ExpressionBuilder.from("someMap.contains(9)").build().execute(extent)); + assertEquals(false, ExpressionBuilder.from("someMap.contains(11)").build().execute(extent)); + }); + } + } + @Test void testUtility() throws Exception { try (var context = new InMemoryContext(true)) { @@ -58,7 +301,7 @@ void testUtility() throws Exception { for (int i = 0; i < 100; ++i) { final var bytes = ExpressionBuilder.from( - "utility:genRandPayload([1024, 1024 * 10, 1024 * 100, 1024 * 1000])" + "std:genRandPayload([1024, 1024 * 10, 1024 * 100, 1024 * 1000])" ) .build() .execute(extent); @@ -105,9 +348,6 @@ void testExpressionNegative() throws Exception { assertThrows(UnsupportedOperationException.class, () -> ExpressionBuilder.from("1 + ").build().execute(extent) ); - assertThrows(UnsupportedOperationException.class, () -> - ExpressionBuilder.from("varOneInt = 2").build().execute(extent) - ); // Throws at runtime assertThrows(UnsupportedOperationException.class, () -> diff --git a/src/test/java/at/ac/uibk/dps/cirrina/runtime/CompleteTest.kt b/src/test/java/at/ac/uibk/dps/cirrina/runtime/CompleteTest.kt index dbdce29c..16615ca5 100644 --- a/src/test/java/at/ac/uibk/dps/cirrina/runtime/CompleteTest.kt +++ b/src/test/java/at/ac/uibk/dps/cirrina/runtime/CompleteTest.kt @@ -9,7 +9,7 @@ import at.ac.uibk.dps.cirrina.execution.`object`.event.Event import at.ac.uibk.dps.cirrina.execution.`object`.event.EventHandler import at.ac.uibk.dps.cirrina.execution.`object`.exchange.ContextVariableProtos import at.ac.uibk.dps.cirrina.execution.`object`.exchange.EventProtos -import at.ac.uibk.dps.cirrina.execution.`object`.expression.Utility +import at.ac.uibk.dps.cirrina.execution.`object`.expression.Stdlib import at.ac.uibk.dps.cirrina.execution.service.RandomServiceImplementationSelector import at.ac.uibk.dps.cirrina.execution.service.ServiceImplementationBuilder import at.ac.uibk.dps.cirrina.io.plantuml.CollaborativeStateMachineExporter @@ -139,12 +139,12 @@ class CompleteTest { } @Test - fun testUtility() { + fun testStdlib() { val sizes = intArrayOf(10, 50, 100, 200, 500) val sizeSet = HashSet() repeat(100) { - val payload = Utility.genRandPayload(sizes) + val payload = Stdlib.genRandPayload(sizes) assertNotNull(payload) assertTrue(sizes.any { it == payload.size }) sizeSet.add(payload.size)