diff --git a/src/main/java/org/biscuitsec/biscuit/crypto/BlockSignatureBuffer.java b/src/main/java/org/biscuitsec/biscuit/crypto/BlockSignatureBuffer.java index dafb3afe..beb01bb0 100644 --- a/src/main/java/org/biscuitsec/biscuit/crypto/BlockSignatureBuffer.java +++ b/src/main/java/org/biscuitsec/biscuit/crypto/BlockSignatureBuffer.java @@ -1,33 +1,47 @@ package org.biscuitsec.biscuit.crypto; -import org.biscuitsec.biscuit.token.format.ExternalSignature; - import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Optional; +import org.biscuitsec.biscuit.token.format.ExternalSignature; + +public final class BlockSignatureBuffer { + public static final int HEADER_SIZE = 4; + + private BlockSignatureBuffer() {} -public class BlockSignatureBuffer { - public static byte[] getBufferSignature(PublicKey nextPubKey, byte[] data) { - return getBufferSignature(nextPubKey, data, Optional.empty()); - } + public static byte[] getBufferSignature(PublicKey nextPubKey, byte[] data) { + return getBufferSignature(nextPubKey, data, Optional.empty()); + } - public static byte[] getBufferSignature(PublicKey nextPubKey, byte[] data, Optional externalSignature) { - var buffer = ByteBuffer.allocate(4 + data.length + nextPubKey.toBytes().length + externalSignature.map((a) -> a.signature.length).orElse(0)).order(ByteOrder.LITTLE_ENDIAN); - buffer.put(data); - externalSignature.ifPresent(signature -> buffer.put(signature.signature)); - buffer.putInt(nextPubKey.algorithm.getNumber()); - buffer.put(nextPubKey.toBytes()); - buffer.flip(); - return buffer.array(); - } + public static byte[] getBufferSignature( + PublicKey nextPubKey, byte[] data, Optional externalSignature) { + var buffer = + ByteBuffer.allocate( + HEADER_SIZE + + data.length + + nextPubKey.toBytes().length + + externalSignature.map((a) -> a.getSignature().length).orElse(0)) + .order(ByteOrder.LITTLE_ENDIAN); + buffer.put(data); + externalSignature.ifPresent(signature -> buffer.put(signature.getSignature())); + buffer.putInt(nextPubKey.getAlgorithm().getNumber()); + buffer.put(nextPubKey.toBytes()); + buffer.flip(); + return buffer.array(); + } - public static byte[] getBufferSealedSignature(PublicKey nextPubKey, byte[] data, byte[] blockSignature) { - var buffer = ByteBuffer.allocate(4 + data.length + nextPubKey.toBytes().length + blockSignature.length).order(ByteOrder.LITTLE_ENDIAN); - buffer.put(data); - buffer.putInt(nextPubKey.algorithm.getNumber()); - buffer.put(nextPubKey.toBytes()); - buffer.put(blockSignature); - buffer.flip(); - return buffer.array(); - } + public static byte[] getBufferSealedSignature( + PublicKey nextPubKey, byte[] data, byte[] blockSignature) { + var buffer = + ByteBuffer.allocate( + HEADER_SIZE + data.length + nextPubKey.toBytes().length + blockSignature.length) + .order(ByteOrder.LITTLE_ENDIAN); + buffer.put(data); + buffer.putInt(nextPubKey.getAlgorithm().getNumber()); + buffer.put(nextPubKey.toBytes()); + buffer.put(blockSignature); + buffer.flip(); + return buffer.array(); + } } diff --git a/src/main/java/org/biscuitsec/biscuit/crypto/Ed25519KeyPair.java b/src/main/java/org/biscuitsec/biscuit/crypto/Ed25519KeyPair.java index dd290f29..7b7815d0 100644 --- a/src/main/java/org/biscuitsec/biscuit/crypto/Ed25519KeyPair.java +++ b/src/main/java/org/biscuitsec/biscuit/crypto/Ed25519KeyPair.java @@ -1,6 +1,12 @@ package org.biscuitsec.biscuit.crypto; import biscuit.format.schema.Schema; +import java.security.InvalidKeyException; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.security.Signature; +import java.security.SignatureException; import net.i2p.crypto.eddsa.EdDSAEngine; import net.i2p.crypto.eddsa.EdDSAPrivateKey; import net.i2p.crypto.eddsa.EdDSAPublicKey; @@ -10,80 +16,75 @@ import net.i2p.crypto.eddsa.spec.EdDSAPublicKeySpec; import org.biscuitsec.biscuit.token.builder.Utils; -import java.security.InvalidKeyException; -import java.security.MessageDigest; -import java.security.NoSuchAlgorithmException; -import java.security.SecureRandom; -import java.security.Signature; -import java.security.SignatureException; - final class Ed25519KeyPair extends KeyPair { + private static final int BUFFER_SIZE = 32; - static final int SIGNATURE_LENGTH = 64; - - private final EdDSAPrivateKey privateKey; - private final EdDSAPublicKey publicKey; + public static final int SIGNATURE_LENGTH = 64; - private static final EdDSANamedCurveSpec ed25519 = EdDSANamedCurveTable.getByName(EdDSANamedCurveTable.ED_25519); + private final EdDSAPrivateKey privateKey; + private final EdDSAPublicKey publicKey; - public Ed25519KeyPair(byte[] bytes) { - EdDSAPrivateKeySpec privKeySpec = new EdDSAPrivateKeySpec(bytes, ed25519); - EdDSAPrivateKey privKey = new EdDSAPrivateKey(privKeySpec); + private static final EdDSANamedCurveSpec ED_25519 = + EdDSANamedCurveTable.getByName(EdDSANamedCurveTable.ED_25519); - EdDSAPublicKeySpec pubKeySpec = new EdDSAPublicKeySpec(privKey.getA(), ed25519); - EdDSAPublicKey pubKey = new EdDSAPublicKey(pubKeySpec); + Ed25519KeyPair(byte[] bytes) { + EdDSAPrivateKeySpec privKeySpec = new EdDSAPrivateKeySpec(bytes, ED_25519); + EdDSAPrivateKey privKey = new EdDSAPrivateKey(privKeySpec); - this.privateKey = privKey; - this.publicKey = pubKey; - } + EdDSAPublicKeySpec pubKeySpec = new EdDSAPublicKeySpec(privKey.getA(), ED_25519); + EdDSAPublicKey pubKey = new EdDSAPublicKey(pubKeySpec); - public Ed25519KeyPair(SecureRandom rng) { - byte[] b = new byte[32]; - rng.nextBytes(b); + this.privateKey = privKey; + this.publicKey = pubKey; + } - EdDSAPrivateKeySpec privKeySpec = new EdDSAPrivateKeySpec(b, ed25519); - EdDSAPrivateKey privKey = new EdDSAPrivateKey(privKeySpec); + Ed25519KeyPair(SecureRandom rng) { + byte[] b = new byte[BUFFER_SIZE]; + rng.nextBytes(b); - EdDSAPublicKeySpec pubKeySpec = new EdDSAPublicKeySpec(privKey.getA(), ed25519); - EdDSAPublicKey pubKey = new EdDSAPublicKey(pubKeySpec); + EdDSAPrivateKeySpec privKeySpec = new EdDSAPrivateKeySpec(b, ED_25519); + EdDSAPrivateKey privKey = new EdDSAPrivateKey(privKeySpec); - this.privateKey = privKey; - this.publicKey = pubKey; - } + EdDSAPublicKeySpec pubKeySpec = new EdDSAPublicKeySpec(privKey.getA(), ED_25519); + EdDSAPublicKey pubKey = new EdDSAPublicKey(pubKeySpec); - public Ed25519KeyPair(String hex) { - this(Utils.hexStringToByteArray(hex)); - } + this.privateKey = privKey; + this.publicKey = pubKey; + } - public static java.security.PublicKey decode(byte[] data) { - return new EdDSAPublicKey(new EdDSAPublicKeySpec(data, ed25519)); - } + Ed25519KeyPair(String hex) { + this(Utils.hexStringToByteArray(hex)); + } - public static Signature getSignature() throws NoSuchAlgorithmException { - return new EdDSAEngine(MessageDigest.getInstance(ed25519.getHashAlgorithm())); - } + public static java.security.PublicKey decode(byte[] data) { + return new EdDSAPublicKey(new EdDSAPublicKeySpec(data, ED_25519)); + } - @Override - public byte[] sign(byte[] data) throws NoSuchAlgorithmException, InvalidKeyException, SignatureException { - Signature sgr = KeyPair.generateSignature(Schema.PublicKey.Algorithm.Ed25519); - sgr.initSign(privateKey); - sgr.update(data); - return sgr.sign(); - } + public static Signature getSignature() throws NoSuchAlgorithmException { + return new EdDSAEngine(MessageDigest.getInstance(ED_25519.getHashAlgorithm())); + } - @Override - public byte[] toBytes() { - return privateKey.getSeed(); - } + @Override + public byte[] sign(byte[] data) + throws NoSuchAlgorithmException, InvalidKeyException, SignatureException { + Signature sgr = KeyPair.generateSignature(Schema.PublicKey.Algorithm.Ed25519); + sgr.initSign(privateKey); + sgr.update(data); + return sgr.sign(); + } - @Override - public String toHex() { - return Utils.byteArrayToHexString(toBytes()); - } + @Override + public byte[] toBytes() { + return privateKey.getSeed(); + } - @Override - public PublicKey public_key() { - return new PublicKey(Schema.PublicKey.Algorithm.Ed25519, this.publicKey); - } + @Override + public String toHex() { + return Utils.byteArrayToHexString(toBytes()); + } + @Override + public PublicKey getPublicKey() { + return new PublicKey(Schema.PublicKey.Algorithm.Ed25519, this.publicKey); + } } diff --git a/src/main/java/org/biscuitsec/biscuit/crypto/KeyDelegate.java b/src/main/java/org/biscuitsec/biscuit/crypto/KeyDelegate.java index 234eebdf..23e39e15 100644 --- a/src/main/java/org/biscuitsec/biscuit/crypto/KeyDelegate.java +++ b/src/main/java/org/biscuitsec/biscuit/crypto/KeyDelegate.java @@ -2,13 +2,12 @@ import io.vavr.control.Option; - /** * Used to find the key associated with a key id * - * When the root key is changed, it might happen that multiple root keys are in use at the same time. - * Tokens can carry a root key id, that can be used to indicate which key will verify it. + *

When the root key is changed, it might happen that multiple root keys are in use at the same + * time. Tokens can carry a root key id, that can be used to indicate which key will verify it. */ public interface KeyDelegate { - public Option root_key(Option key_id); + Option getRootKey(Option keyId); } diff --git a/src/main/java/org/biscuitsec/biscuit/crypto/KeyPair.java b/src/main/java/org/biscuitsec/biscuit/crypto/KeyPair.java index f6107b75..62498eb5 100644 --- a/src/main/java/org/biscuitsec/biscuit/crypto/KeyPair.java +++ b/src/main/java/org/biscuitsec/biscuit/crypto/KeyPair.java @@ -1,69 +1,66 @@ package org.biscuitsec.biscuit.crypto; - import biscuit.format.schema.Schema.PublicKey.Algorithm; -import net.i2p.crypto.eddsa.Utils; - import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; import java.security.Signature; import java.security.SignatureException; +import net.i2p.crypto.eddsa.Utils; -/** - * Private and public key. - */ +/** Private and public key. */ public abstract class KeyPair implements Signer { - public static KeyPair generate(Algorithm algorithm) { - return generate(algorithm, new SecureRandom()); - } + public static KeyPair generate(Algorithm algorithm) { + return generate(algorithm, new SecureRandom()); + } - public static KeyPair generate(Algorithm algorithm, String hex) { - return generate(algorithm, Utils.hexToBytes(hex)); - } + public static KeyPair generate(Algorithm algorithm, String hex) { + return generate(algorithm, Utils.hexToBytes(hex)); + } - public static KeyPair generate(Algorithm algorithm, byte[] bytes) { - if (algorithm == Algorithm.Ed25519) { - return new Ed25519KeyPair(bytes); - } else if (algorithm == Algorithm.SECP256R1) { - return new SECP256R1KeyPair(bytes); - } else { - throw new IllegalArgumentException("Unsupported algorithm"); - } + public static KeyPair generate(Algorithm algorithm, byte[] bytes) { + if (algorithm == Algorithm.Ed25519) { + return new Ed25519KeyPair(bytes); + } else if (algorithm == Algorithm.SECP256R1) { + return new SECP256R1KeyPair(bytes); + } else { + throw new IllegalArgumentException("Unsupported algorithm"); } + } - public static KeyPair generate(Algorithm algorithm, SecureRandom rng) { - if (algorithm == Algorithm.Ed25519) { - return new Ed25519KeyPair(rng); - } else if (algorithm == Algorithm.SECP256R1) { - return new SECP256R1KeyPair(rng); - } else { - throw new IllegalArgumentException("Unsupported algorithm"); - } + public static KeyPair generate(Algorithm algorithm, SecureRandom rng) { + if (algorithm == Algorithm.Ed25519) { + return new Ed25519KeyPair(rng); + } else if (algorithm == Algorithm.SECP256R1) { + return new SECP256R1KeyPair(rng); + } else { + throw new IllegalArgumentException("Unsupported algorithm"); } + } - public static Signature generateSignature(Algorithm algorithm) throws NoSuchAlgorithmException { - if (algorithm == Algorithm.Ed25519) { - return Ed25519KeyPair.getSignature(); - } else if (algorithm == Algorithm.SECP256R1) { - return SECP256R1KeyPair.getSignature(); - } else { - throw new NoSuchAlgorithmException("Unsupported algorithm"); - } + public static Signature generateSignature(Algorithm algorithm) throws NoSuchAlgorithmException { + if (algorithm == Algorithm.Ed25519) { + return Ed25519KeyPair.getSignature(); + } else if (algorithm == Algorithm.SECP256R1) { + return SECP256R1KeyPair.getSignature(); + } else { + throw new NoSuchAlgorithmException("Unsupported algorithm"); } + } - public static boolean verify(PublicKey publicKey, byte[] data, byte[] signature) throws NoSuchAlgorithmException, InvalidKeyException, SignatureException { - Signature sgr = KeyPair.generateSignature(publicKey.algorithm); - sgr.initVerify(publicKey.key); - sgr.update(data); - return sgr.verify(signature); - } + public static boolean verify(PublicKey publicKey, byte[] data, byte[] signature) + throws NoSuchAlgorithmException, InvalidKeyException, SignatureException { + Signature sgr = KeyPair.generateSignature(publicKey.getAlgorithm()); + sgr.initVerify(publicKey.getKey()); + sgr.update(data); + return sgr.verify(signature); + } - public abstract byte[] toBytes(); + public abstract byte[] toBytes(); - public abstract String toHex(); + public abstract String toHex(); - @Override - public abstract PublicKey public_key(); + @Override + public abstract PublicKey getPublicKey(); } diff --git a/src/main/java/org/biscuitsec/biscuit/crypto/PublicKey.java b/src/main/java/org/biscuitsec/biscuit/crypto/PublicKey.java index becf3cdb..aced9765 100644 --- a/src/main/java/org/biscuitsec/biscuit/crypto/PublicKey.java +++ b/src/main/java/org/biscuitsec/biscuit/crypto/PublicKey.java @@ -2,117 +2,132 @@ import biscuit.format.schema.Schema; import biscuit.format.schema.Schema.PublicKey.Algorithm; +import com.google.protobuf.ByteString; +import java.util.Optional; +import java.util.Set; import net.i2p.crypto.eddsa.EdDSAPublicKey; import org.biscuitsec.biscuit.error.Error; import org.biscuitsec.biscuit.token.builder.Utils; -import com.google.protobuf.ByteString; import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey; -import java.util.Optional; -import java.util.Set; +public final class PublicKey { -public class PublicKey { + private final java.security.PublicKey key; + private final Algorithm algorithm; - public final java.security.PublicKey key; - public final Algorithm algorithm; + private static final Set SUPPORTED_ALGORITHMS = + Set.of(Algorithm.Ed25519, Algorithm.SECP256R1); - private static final Set SUPPORTED_ALGORITHMS = Set.of(Algorithm.Ed25519, Algorithm.SECP256R1); + public PublicKey(Algorithm algorithm, java.security.PublicKey publicKey) { + this.key = publicKey; + this.algorithm = algorithm; + } - public PublicKey(Algorithm algorithm, java.security.PublicKey public_key) { - this.key = public_key; - this.algorithm = algorithm; + public PublicKey(Algorithm algorithm, byte[] data) { + if (algorithm == Algorithm.Ed25519) { + this.key = Ed25519KeyPair.decode(data); + } else if (algorithm == Algorithm.SECP256R1) { + this.key = SECP256R1KeyPair.decode(data); + } else { + throw new IllegalArgumentException("Invalid algorithm"); } - - public PublicKey(Algorithm algorithm, byte[] data) { - if (algorithm == Algorithm.Ed25519) { - this.key = Ed25519KeyPair.decode(data); - } else if (algorithm == Algorithm.SECP256R1) { - this.key = SECP256R1KeyPair.decode(data); - } else { - throw new IllegalArgumentException("Invalid algorithm"); - } - this.algorithm = algorithm; + this.algorithm = algorithm; + } + + public PublicKey(Algorithm algorithm, String hex) { + byte[] data = Utils.hexStringToByteArray(hex); + if (algorithm == Algorithm.Ed25519) { + this.key = Ed25519KeyPair.decode(data); + } else if (algorithm == Algorithm.SECP256R1) { + this.key = SECP256R1KeyPair.decode(data); + } else { + throw new IllegalArgumentException("Invalid algorithm"); } - - public byte[] toBytes() { - if (algorithm == Algorithm.Ed25519) { - return ((EdDSAPublicKey) key).getAbyte(); - } else if (algorithm == Algorithm.SECP256R1) { - return ((BCECPublicKey) key).getQ().getEncoded(true); // true = compressed - } else { - throw new IllegalArgumentException("Invalid algorithm"); - } + this.algorithm = algorithm; + } + + public byte[] toBytes() { + if (this.algorithm == Algorithm.Ed25519) { + return ((EdDSAPublicKey) getKey()).getAbyte(); + } else if (this.algorithm == Algorithm.SECP256R1) { + return ((BCECPublicKey) getKey()).getQ().getEncoded(true); // true = compressed + } else { + throw new IllegalArgumentException("Invalid algorithm"); } - - public String toHex() { - return Utils.byteArrayToHexString(this.toBytes()); + } + + public String toHex() { + return Utils.byteArrayToHexString(this.toBytes()); + } + + public Schema.PublicKey serialize() { + Schema.PublicKey.Builder publicKey = Schema.PublicKey.newBuilder(); + publicKey.setKey(ByteString.copyFrom(this.toBytes())); + publicKey.setAlgorithm(this.algorithm); + return publicKey.build(); + } + + public static PublicKey deserialize(Schema.PublicKey pk) + throws Error.FormatError.DeserializationError { + if (!pk.hasAlgorithm() || !pk.hasKey() || !SUPPORTED_ALGORITHMS.contains(pk.getAlgorithm())) { + throw new Error.FormatError.DeserializationError("Invalid public key"); } - - public PublicKey(Algorithm algorithm, String hex) { - byte[] data = Utils.hexStringToByteArray(hex); - if (algorithm == Algorithm.Ed25519) { - this.key = Ed25519KeyPair.decode(data); - } else if (algorithm == Algorithm.SECP256R1) { - this.key = SECP256R1KeyPair.decode(data); - } else { - throw new IllegalArgumentException("Invalid algorithm"); - } - this.algorithm = algorithm; + return new PublicKey(pk.getAlgorithm(), pk.getKey().toByteArray()); + } + + public static Optional validateSignatureLength(Algorithm algorithm, int length) { + Optional error = Optional.empty(); + if (algorithm == Algorithm.Ed25519) { + if (length != Ed25519KeyPair.SIGNATURE_LENGTH) { + error = Optional.of(new Error.FormatError.Signature.InvalidSignatureSize(length)); + } + } else if (algorithm == Algorithm.SECP256R1) { + if (length < SECP256R1KeyPair.MINIMUM_SIGNATURE_LENGTH + || length > SECP256R1KeyPair.MAXIMUM_SIGNATURE_LENGTH) { + error = Optional.of(new Error.FormatError.Signature.InvalidSignatureSize(length)); + } + } else { + error = + Optional.of(new Error.FormatError.Signature.InvalidSignature("unsupported algorithm")); } + return error; + } - public Schema.PublicKey serialize() { - Schema.PublicKey.Builder publicKey = Schema.PublicKey.newBuilder(); - publicKey.setKey(ByteString.copyFrom(this.toBytes())); - publicKey.setAlgorithm(this.algorithm); - return publicKey.build(); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - static public PublicKey deserialize(Schema.PublicKey pk) throws Error.FormatError.DeserializationError { - if(!pk.hasAlgorithm() || !pk.hasKey() || !SUPPORTED_ALGORITHMS.contains(pk.getAlgorithm())) { - throw new Error.FormatError.DeserializationError("Invalid public key"); - } - return new PublicKey(pk.getAlgorithm(), pk.getKey().toByteArray()); + if (o == null || getClass() != o.getClass()) { + return false; } - public static Optional validateSignatureLength(Algorithm algorithm, int length) { - Optional error = Optional.empty(); - if (algorithm == Algorithm.Ed25519) { - if (length != Ed25519KeyPair.SIGNATURE_LENGTH) { - error = Optional.of(new Error.FormatError.Signature.InvalidSignatureSize(length)); - } - } else if (algorithm == Algorithm.SECP256R1) { - if (length < SECP256R1KeyPair.MINIMUM_SIGNATURE_LENGTH || length > SECP256R1KeyPair.MAXIMUM_SIGNATURE_LENGTH) { - error = Optional.of(new Error.FormatError.Signature.InvalidSignatureSize(length)); - } - } else { - error = Optional.of(new Error.FormatError.Signature.InvalidSignature("unsupported algorithm")); - } - return error; - } + PublicKey publicKey = (PublicKey) o; - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + return this.key.equals(publicKey.getKey()); + } - PublicKey publicKey = (PublicKey) o; + @Override + public int hashCode() { + return getKey().hashCode(); + } - return key.equals(publicKey.key); + @Override + public String toString() { + if (this.algorithm == Algorithm.Ed25519) { + return "ed25519/" + toHex().toLowerCase(); + } else if (this.algorithm == Algorithm.SECP256R1) { + return "secp256r1/" + toHex().toLowerCase(); + } else { + return null; } + } - @Override - public int hashCode() { - return key.hashCode(); - } + public java.security.PublicKey getKey() { + return this.key; + } - @Override - public String toString() { - if (algorithm == Algorithm.Ed25519) { - return "ed25519/" + toHex().toLowerCase(); - } else if (algorithm == Algorithm.SECP256R1) { - return "secp256r1/" + toHex().toLowerCase(); - } else { - return null; - } - } + public Algorithm getAlgorithm() { + return this.algorithm; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/crypto/SECP256R1KeyPair.java b/src/main/java/org/biscuitsec/biscuit/crypto/SECP256R1KeyPair.java index dc9222e2..c7d4aa5f 100644 --- a/src/main/java/org/biscuitsec/biscuit/crypto/SECP256R1KeyPair.java +++ b/src/main/java/org/biscuitsec/biscuit/crypto/SECP256R1KeyPair.java @@ -1,6 +1,12 @@ package org.biscuitsec.biscuit.crypto; import biscuit.format.schema.Schema; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.security.Security; +import java.security.Signature; +import java.security.SignatureException; import org.biscuitsec.biscuit.token.builder.Utils; import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPrivateKey; import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey; @@ -11,89 +17,89 @@ import org.bouncycastle.jce.spec.ECPublicKeySpec; import org.bouncycastle.util.BigIntegers; -import java.security.InvalidKeyException; -import java.security.NoSuchAlgorithmException; -import java.security.SecureRandom; -import java.security.Security; -import java.security.Signature; -import java.security.SignatureException; - +@SuppressWarnings("checkstyle:AbbreviationAsWordInName") final class SECP256R1KeyPair extends KeyPair { - static final int MINIMUM_SIGNATURE_LENGTH = 68; - static final int MAXIMUM_SIGNATURE_LENGTH = 72; - - private final BCECPrivateKey privateKey; - private final BCECPublicKey publicKey; - - private static final String ALGORITHM = "ECDSA"; - private static final String CURVE = "secp256r1"; - private static final ECNamedCurveParameterSpec SECP256R1 = ECNamedCurveTable.getParameterSpec(CURVE); - - static { - Security.addProvider(new BouncyCastleProvider()); - } - - public SECP256R1KeyPair(byte[] bytes) { - var privateKeySpec = new ECPrivateKeySpec(BigIntegers.fromUnsignedByteArray(bytes), SECP256R1); - var privateKey = new BCECPrivateKey(ALGORITHM, privateKeySpec, BouncyCastleProvider.CONFIGURATION); - - var publicKeySpec = new ECPublicKeySpec(SECP256R1.getG().multiply(privateKeySpec.getD()), SECP256R1); - var publicKey = new BCECPublicKey(ALGORITHM, publicKeySpec, BouncyCastleProvider.CONFIGURATION); - - this.privateKey = privateKey; - this.publicKey = publicKey; - } - - public SECP256R1KeyPair(SecureRandom rng) { - byte[] bytes = new byte[32]; - rng.nextBytes(bytes); - - var privateKeySpec = new ECPrivateKeySpec(BigIntegers.fromUnsignedByteArray(bytes), SECP256R1); - var privateKey = new BCECPrivateKey(ALGORITHM, privateKeySpec, BouncyCastleProvider.CONFIGURATION); - - var publicKeySpec = new ECPublicKeySpec(SECP256R1.getG().multiply(privateKeySpec.getD()), SECP256R1); - var publicKey = new BCECPublicKey(ALGORITHM, publicKeySpec, BouncyCastleProvider.CONFIGURATION); - - this.privateKey = privateKey; - this.publicKey = publicKey; - } - - public SECP256R1KeyPair(String hex) { - this(Utils.hexStringToByteArray(hex)); - } - - public static java.security.PublicKey decode(byte[] data) { - var params = ECNamedCurveTable.getParameterSpec(CURVE); - var spec = new ECPublicKeySpec(params.getCurve().decodePoint(data), params); - return new BCECPublicKey(ALGORITHM, spec, BouncyCastleProvider.CONFIGURATION); - } - - public static Signature getSignature() throws NoSuchAlgorithmException { - return Signature.getInstance("SHA256withECDSA", new BouncyCastleProvider()); - } - - @Override - public byte[] sign(byte[] data) throws NoSuchAlgorithmException, InvalidKeyException, SignatureException { - Signature sgr = KeyPair.generateSignature(Schema.PublicKey.Algorithm.SECP256R1); - sgr.initSign(privateKey); - sgr.update(data); - return sgr.sign(); - } - - @Override - public byte[] toBytes() { - return BigIntegers.asUnsignedByteArray(privateKey.getD()); - } - - @Override - public String toHex() { - return Utils.byteArrayToHexString(toBytes()); - } - - @Override - public PublicKey public_key() { - return new PublicKey(Schema.PublicKey.Algorithm.SECP256R1, publicKey); - } - + static final int MINIMUM_SIGNATURE_LENGTH = 68; + static final int MAXIMUM_SIGNATURE_LENGTH = 72; + private static final int BUFFER_SIZE = 32; + + private final BCECPrivateKey privateKey; + private final BCECPublicKey publicKey; + + private static final String ALGORITHM = "ECDSA"; + private static final String CURVE = "secp256r1"; + private static final ECNamedCurveParameterSpec SECP256R1 = + ECNamedCurveTable.getParameterSpec(CURVE); + + static { + Security.addProvider(new BouncyCastleProvider()); + } + + SECP256R1KeyPair(byte[] bytes) { + var privateKeySpec = new ECPrivateKeySpec(BigIntegers.fromUnsignedByteArray(bytes), SECP256R1); + var privateKey = + new BCECPrivateKey(ALGORITHM, privateKeySpec, BouncyCastleProvider.CONFIGURATION); + + var publicKeySpec = + new ECPublicKeySpec(SECP256R1.getG().multiply(privateKeySpec.getD()), SECP256R1); + var publicKey = new BCECPublicKey(ALGORITHM, publicKeySpec, BouncyCastleProvider.CONFIGURATION); + + this.privateKey = privateKey; + this.publicKey = publicKey; + } + + SECP256R1KeyPair(SecureRandom rng) { + byte[] bytes = new byte[BUFFER_SIZE]; + rng.nextBytes(bytes); + + var privateKeySpec = new ECPrivateKeySpec(BigIntegers.fromUnsignedByteArray(bytes), SECP256R1); + var privateKey = + new BCECPrivateKey(ALGORITHM, privateKeySpec, BouncyCastleProvider.CONFIGURATION); + + var publicKeySpec = + new ECPublicKeySpec(SECP256R1.getG().multiply(privateKeySpec.getD()), SECP256R1); + var publicKey = new BCECPublicKey(ALGORITHM, publicKeySpec, BouncyCastleProvider.CONFIGURATION); + + this.privateKey = privateKey; + this.publicKey = publicKey; + } + + SECP256R1KeyPair(String hex) { + this(Utils.hexStringToByteArray(hex)); + } + + public static java.security.PublicKey decode(byte[] data) { + var params = ECNamedCurveTable.getParameterSpec(CURVE); + var spec = new ECPublicKeySpec(params.getCurve().decodePoint(data), params); + return new BCECPublicKey(ALGORITHM, spec, BouncyCastleProvider.CONFIGURATION); + } + + public static Signature getSignature() throws NoSuchAlgorithmException { + return Signature.getInstance("SHA256withECDSA", new BouncyCastleProvider()); + } + + @Override + public byte[] sign(byte[] data) + throws NoSuchAlgorithmException, InvalidKeyException, SignatureException { + Signature sgr = KeyPair.generateSignature(Schema.PublicKey.Algorithm.SECP256R1); + sgr.initSign(privateKey); + sgr.update(data); + return sgr.sign(); + } + + @Override + public byte[] toBytes() { + return BigIntegers.asUnsignedByteArray(privateKey.getD()); + } + + @Override + public String toHex() { + return Utils.byteArrayToHexString(toBytes()); + } + + @Override + public PublicKey getPublicKey() { + return new PublicKey(Schema.PublicKey.Algorithm.SECP256R1, publicKey); + } } diff --git a/src/main/java/org/biscuitsec/biscuit/crypto/Signer.java b/src/main/java/org/biscuitsec/biscuit/crypto/Signer.java index cf3266ac..1db14dca 100644 --- a/src/main/java/org/biscuitsec/biscuit/crypto/Signer.java +++ b/src/main/java/org/biscuitsec/biscuit/crypto/Signer.java @@ -5,25 +5,26 @@ import java.security.SignatureException; /** - * Interface to enable the cryptographic signature of payload. - * It can be adapted depending of the needs + * Interface to enable the cryptographic signature of payload. It can be adapted depending of the + * needs */ public interface Signer { - /** - * Sign the payload with the signer key - * - * @param payload - * @return the signature of payload by - * @throws NoSuchAlgorithmException - * @throws InvalidKeyException - * @throws SignatureException - */ - public byte[] sign(byte[] payload) throws NoSuchAlgorithmException, InvalidKeyException, SignatureException; + /** + * Sign the payload with the signer key + * + * @param payload + * @return the signature of payload by + * @throws NoSuchAlgorithmException + * @throws InvalidKeyException + * @throws SignatureException + */ + byte[] sign(byte[] payload) + throws NoSuchAlgorithmException, InvalidKeyException, SignatureException; - /** - * Return the public key of the signer and the associated algorithm - * - * @return - */ - public PublicKey public_key(); + /** + * Return the public key of the signer and the associated algorithm + * + * @return + */ + PublicKey getPublicKey(); } diff --git a/src/main/java/org/biscuitsec/biscuit/crypto/Token.java b/src/main/java/org/biscuitsec/biscuit/crypto/Token.java index 010125c7..63a2eb0b 100644 --- a/src/main/java/org/biscuitsec/biscuit/crypto/Token.java +++ b/src/main/java/org/biscuitsec/biscuit/crypto/Token.java @@ -1,77 +1,84 @@ package org.biscuitsec.biscuit.crypto; -import org.biscuitsec.biscuit.error.Error; -import io.vavr.control.Either; +import static io.vavr.API.Left; +import static io.vavr.API.Right; +import io.vavr.control.Either; import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.security.SignatureException; import java.util.ArrayList; - -import static io.vavr.API.Left; -import static io.vavr.API.Right; +import org.biscuitsec.biscuit.error.Error; class Token { - public final ArrayList blocks; - public final ArrayList keys; - public final ArrayList signatures; - public final KeyPair next; + private final ArrayList blocks; + private final ArrayList keys; + private final ArrayList signatures; + private final KeyPair next; - public Token(final Signer rootSigner, byte[] message, KeyPair next) throws NoSuchAlgorithmException, InvalidKeyException, SignatureException { + Token(final Signer rootSigner, byte[] message, KeyPair next) + throws NoSuchAlgorithmException, InvalidKeyException, SignatureException { - byte[] payload = BlockSignatureBuffer.getBufferSignature(next.public_key(), message); + this.blocks = new ArrayList<>(); + this.blocks.add(message); + this.keys = new ArrayList<>(); + this.keys.add(next.getPublicKey()); + this.signatures = new ArrayList<>(); + byte[] payload = BlockSignatureBuffer.getBufferSignature(next.getPublicKey(), message); + byte[] signature = rootSigner.sign(payload); + this.signatures.add(signature); + this.next = next; + } - byte[] signature = rootSigner.sign(payload); + Token( + final ArrayList blocks, + final ArrayList keys, + final ArrayList signatures, + final KeyPair next) { + this.signatures = signatures; + this.blocks = blocks; + this.keys = keys; + this.next = next; + } - this.blocks = new ArrayList<>(); - this.blocks.add(message); - this.keys = new ArrayList<>(); - this.keys.add(next.public_key()); - this.signatures = new ArrayList<>(); - this.signatures.add(signature); - this.next = next; - } + Token append(KeyPair keyPair, byte[] message) + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException { + byte[] payload = BlockSignatureBuffer.getBufferSignature(keyPair.getPublicKey(), message); + byte[] signature = this.next.sign(payload); - public Token(final ArrayList blocks, final ArrayList keys, final ArrayList signatures, - final KeyPair next) { - this.signatures = signatures; - this.blocks = blocks; - this.keys = keys; - this.next = next; - } + Token token = new Token(this.blocks, this.keys, this.signatures, keyPair); + token.blocks.add(message); + token.signatures.add(signature); + token.keys.add(keyPair.getPublicKey()); - public Token append(KeyPair keyPair, byte[] message) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException { - byte[] payload = BlockSignatureBuffer.getBufferSignature(keyPair.public_key(), message); - byte[] signature = this.next.sign(payload); + return token; + } - Token token = new Token(this.blocks, this.keys, this.signatures, keyPair); - token.blocks.add(message); - token.signatures.add(signature); - token.keys.add(keyPair.public_key()); + // FIXME: rust version returns a Result<(), error::Signature> + public Either verify(PublicKey root) + throws NoSuchAlgorithmException, InvalidKeyException, SignatureException { + PublicKey currentKey = root; + for (int i = 0; i < this.blocks.size(); i++) { + byte[] block = this.blocks.get(i); + PublicKey nextKey = this.keys.get(i); + byte[] signature = this.signatures.get(i); - return token; + byte[] payload = BlockSignatureBuffer.getBufferSignature(nextKey, block); + if (KeyPair.verify(currentKey, payload, signature)) { + currentKey = nextKey; + } else { + return Left( + new Error.FormatError.Signature.InvalidSignature( + "signature error: Verification equation was not satisfied")); + } } - // FIXME: rust version returns a Result<(), error::Signature> - public Either verify(PublicKey root) throws NoSuchAlgorithmException, InvalidKeyException, SignatureException { - PublicKey current_key = root; - for(int i = 0; i < this.blocks.size(); i++) { - byte[] block = this.blocks.get(i); - PublicKey next_key = this.keys.get(i); - byte[] signature = this.signatures.get(i); - - byte[] payload = BlockSignatureBuffer.getBufferSignature(next_key, block); - if (KeyPair.verify(current_key, payload, signature)) { - current_key = next_key; - } else { - return Left(new Error.FormatError.Signature.InvalidSignature("signature error: Verification equation was not satisfied")); - } - } - - if (this.next.public_key().equals(current_key)) { - return Right(null); - } else { - return Left(new Error.FormatError.Signature.InvalidSignature("signature error: Verification equation was not satisfied")); - } + if (this.next.getPublicKey().equals(currentKey)) { + return Right(null); + } else { + return Left( + new Error.FormatError.Signature.InvalidSignature( + "signature error: Verification equation was not satisfied")); } + } } diff --git a/src/main/java/org/biscuitsec/biscuit/crypto/TokenSignature.java b/src/main/java/org/biscuitsec/biscuit/crypto/TokenSignature.java index 2d89ae22..d8cc5941 100644 --- a/src/main/java/org/biscuitsec/biscuit/crypto/TokenSignature.java +++ b/src/main/java/org/biscuitsec/biscuit/crypto/TokenSignature.java @@ -2,184 +2,19 @@ import org.biscuitsec.biscuit.token.builder.Utils; -/** - * Signature aggregation - */ -public class TokenSignature { - /* - final ArrayList parameters; - final Scalar z; +/** Signature aggregation */ +public final class TokenSignature { + private TokenSignature() {} - - public TokenSignature(final SecureRandom rng, KeyPair keypair, byte[] message) { - byte[] b = new byte[64]; - rng.nextBytes(b); - Scalar r = Scalar.fromBytesModOrderWide(b); - - RistrettoElement A = Constants.RISTRETTO_GENERATOR.multiply(r); - ArrayList l = new ArrayList<>(); - l.add(A); - Scalar d = hash_points(l); - Scalar e = hash_message(keypair.public_key, message); - - Scalar z = r.multiply(d).subtract(e.multiply(keypair.private_key)); - - this.parameters = l; - this.z = z; - } - - TokenSignature(final ArrayList parameters, final Scalar z) { - this.parameters = parameters; - this.z = z; - } - - - public TokenSignature sign(final SecureRandom rng, KeyPair keypair, byte[] message) { - byte[] b = new byte[64]; - rng.nextBytes(b); - Scalar r = Scalar.fromBytesModOrderWide(b); - - RistrettoElement A = Constants.RISTRETTO_GENERATOR.multiply(r); - ArrayList l = new ArrayList<>(); - l.add(A); - Scalar d = hash_points(l); - Scalar e = hash_message(keypair.public_key, message); - - Scalar z = r.multiply(d).subtract(e.multiply(keypair.private_key)); - - TokenSignature sig = new TokenSignature(this.parameters, this.z.add(z)); - sig.parameters.add(A); - - return sig; - } - - - public Either verify(List public_keys, List messages) { - if (!(public_keys.size() == messages.size() && public_keys.size() == this.parameters.size())) { - System.out.println(("lists are not the same size")); - return Left(new Error.FormatError.Signature.InvalidFormat()); - } - - //System.out.println("z, zp"); - RistrettoElement zP = Constants.RISTRETTO_GENERATOR.multiply(this.z); - //System.out.println(hex(z.toByteArray())); - //System.out.println(hex(zP.compress().toByteArray())); - - - //System.out.println("eiXi"); - RistrettoElement eiXi = RistrettoElement.IDENTITY; - for(int i = 0; i < public_keys.size(); i++) { - Scalar e = hash_message(public_keys.get(i), messages.get(i)); - //System.out.println(hex(e.toByteArray())); - //System.out.println(hex((public_keys.get(i).multiply(e)).compress().toByteArray())); - - - eiXi = eiXi.add(public_keys.get(i).multiply(e)); - //System.out.println(hex(eiXi.compress().toByteArray())); - - } - - //System.out.println("diAi"); - RistrettoElement diAi = RistrettoElement.IDENTITY; - for (RistrettoElement A: parameters) { - ArrayList l = new ArrayList<>(); - l.add(A); - Scalar d = hash_points(l); - - diAi = diAi.add(A.multiply(d)); - } - - //System.out.println(hex(eiXi.compress().toByteArray())); - //System.out.println(hex(diAi.compress().toByteArray())); - - - - RistrettoElement res = zP.add(eiXi).subtract(diAi); - - //System.out.println(hex(RistrettoElement.IDENTITY.compress().toByteArray())); - //System.out.println(hex(res.compress().toByteArray())); - - if (res.ctEquals(RistrettoElement.IDENTITY) == 1) { - return Right(null); - } else { - return Left(new Error.FormatError.Signature.InvalidSignature("signature error: Verification equation was not satisfied")); - } + public static String hex(byte[] byteArray) { + StringBuilder result = new StringBuilder(); + for (byte bb : byteArray) { + result.append(String.format("%02X", bb)); } + return result.toString(); + } - public Schema.Signature serialize() { - Schema.Signature.Builder sig = Schema.Signature.newBuilder() - .setZ(ByteString.copyFrom(this.z.toByteArray())); - - //System.out.println(this.parameters.size()); - for (int i = 0; i < this.parameters.size(); i++) { - //System.out.println(i); - sig.addParameters(ByteString.copyFrom(this.parameters.get(i).compress().toByteArray())); - } - - return sig.build(); - } - - - static public Either deserialize(Schema.Signature sig) { - try { - ArrayList parameters = new ArrayList<>(); - for (ByteString parameter : sig.getParametersList()) { - parameters.add((new CompressedRistretto(parameter.toByteArray())).decompress()); - } - - //System.out.println(hex(sig.getZ().toByteArray())); - //System.out.println(sig.getZ().toByteArray().length); - - Scalar z = Scalar.fromBytesModOrder(sig.getZ().toByteArray()); - - return Right(new TokenSignature(parameters, z)); - } catch (InvalidEncodingException e) { - return Left(new Error.FormatError.Signature.InvalidFormat()); - } catch(IllegalArgumentException e) { - return Left(new Error.FormatError.DeserializationError(e.toString())); - } - } - - static Scalar hash_points(List points) { - try { - MessageDigest digest = MessageDigest.getInstance("SHA-512"); - digest.reset(); - - for (RistrettoElement point : points) { - digest.update(point.compress().toByteArray()); - } - - return Scalar.fromBytesModOrderWide(digest.digest()); - } catch (Exception e) { - e.printStackTrace(); - } - return null; - } - - static Scalar hash_message(RistrettoElement point, byte[] data) { - try { - MessageDigest digest = MessageDigest.getInstance("SHA-512"); - digest.reset(); - - digest.update(point.compress().toByteArray()); - digest.update(data); - - return Scalar.fromBytesModOrderWide(digest.digest()); - } catch (Exception e) { - e.printStackTrace(); - } - return null; - }*/ - - static public String hex(byte[] byteArray) { - StringBuilder result = new StringBuilder(); - for (byte bb : byteArray) { - result.append(String.format("%02X", bb)); - } - return result.toString(); - } - - public static byte[] fromHex(String s) { - return Utils.hexStringToByteArray(s); - } + public static byte[] fromHex(String s) { + return Utils.hexStringToByteArray(s); + } } diff --git a/src/main/java/org/biscuitsec/biscuit/crypto/package-info.java b/src/main/java/org/biscuitsec/biscuit/crypto/package-info.java index 25a7d13c..5cfa7f29 100644 --- a/src/main/java/org/biscuitsec/biscuit/crypto/package-info.java +++ b/src/main/java/org/biscuitsec/biscuit/crypto/package-info.java @@ -1,4 +1,2 @@ -/** - * Cryptographic operations for Biscuit tokens - */ -package org.biscuitsec.biscuit.crypto; \ No newline at end of file +/** Cryptographic operations for Biscuit tokens */ +package org.biscuitsec.biscuit.crypto; diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/Check.java b/src/main/java/org/biscuitsec/biscuit/datalog/Check.java index 467e61cc..8ab138bd 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/Check.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/Check.java @@ -1,109 +1,114 @@ package org.biscuitsec.biscuit.datalog; +import static biscuit.format.schema.Schema.CheckV2.Kind.All; +import static io.vavr.API.Left; +import static io.vavr.API.Right; + import biscuit.format.schema.Schema; -import org.biscuitsec.biscuit.error.Error; import io.vavr.control.Either; - import java.util.ArrayList; import java.util.List; import java.util.Objects; +import org.biscuitsec.biscuit.error.Error; -import static biscuit.format.schema.Schema.CheckV2.Kind.All; -import static io.vavr.API.Left; -import static io.vavr.API.Right; - -public class Check { - public enum Kind { - One, - All - } +public final class Check { + public enum Kind { + ONE, + ALL + } - private final Kind kind; + private static final int HASH_CODE_SEED = 31; - private final List queries; + private final Kind kind; - public Check(Kind kind, List queries) { - this.kind = kind; - this.queries = queries; - } + private final List queries; - public Kind kind() { - return kind; - } + public Check(Kind kind, List queries) { + this.kind = kind; + this.queries = queries; + } - public List queries() { - return queries; - } + public Kind kind() { + return kind; + } - public Schema.CheckV2 serialize() { - Schema.CheckV2.Builder b = Schema.CheckV2.newBuilder(); + public List queries() { + return queries; + } - // do not set the kind to One to keep compatibility with older library versions - switch (this.kind) { - case All: - b.setKind(All); - break; - } + public Schema.CheckV2 serialize() { + Schema.CheckV2.Builder b = Schema.CheckV2.newBuilder(); - for(int i = 0; i < this.queries.size(); i++) { - b.addQueries(this.queries.get(i).serialize()); - } + // do not set the kind to One to keep compatibility with older library versions + switch (this.kind) { + case ALL: + b.setKind(All); + break; + default: + } - return b.build(); + for (int i = 0; i < this.queries.size(); i++) { + b.addQueries(this.queries.get(i).serialize()); } - static public Either deserializeV2(Schema.CheckV2 check) { - ArrayList queries = new ArrayList<>(); - - Kind kind; - switch (check.getKind()) { - case One: - kind = Kind.One; - break; - case All: - kind = Kind.All; - break; - default: - kind = Kind.One; - break; - } - - for (Schema.RuleV2 query: check.getQueriesList()) { - Either res = Rule.deserializeV2(query); - if(res.isLeft()) { - Error.FormatError e = res.getLeft(); - return Left(e); - } else { - queries.add(res.get()); - } - } - - return Right(new Check(kind, queries)); + return b.build(); + } + + public static Either deserializeV2(Schema.CheckV2 check) { + ArrayList queries = new ArrayList<>(); + + Kind kind; + switch (check.getKind()) { + case One: + kind = Kind.ONE; + break; + case All: + kind = Kind.ALL; + break; + default: + kind = Kind.ONE; + break; } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + for (Schema.RuleV2 query : check.getQueriesList()) { + Either res = Rule.deserializeV2(query); + if (res.isLeft()) { + Error.FormatError e = res.getLeft(); + return Left(e); + } else { + queries.add(res.get()); + } + } - Check check = (Check) o; + return Right(new Check(kind, queries)); + } - if (kind != check.kind) return false; - return Objects.equals(queries, check.queries); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - int result = kind != null ? kind.hashCode() : 0; - result = 31 * result + (queries != null ? queries.hashCode() : 0); - return result; + if (o == null || getClass() != o.getClass()) { + return false; } - @Override - public String toString() { - return "Check{" + - "kind=" + kind + - ", queries=" + queries + - '}'; + Check check = (Check) o; + + if (kind != check.kind) { + return false; } + return Objects.equals(queries, check.queries); + } + + @Override + public int hashCode() { + int result = kind != null ? kind.hashCode() : 0; + result = HASH_CODE_SEED * result + (queries != null ? queries.hashCode() : 0); + return result; + } + + @Override + public String toString() { + return "Check{kind=" + kind + ", queries=" + queries + '}'; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/Combinator.java b/src/main/java/org/biscuitsec/biscuit/datalog/Combinator.java index 0a3ff42d..7dc8f8e4 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/Combinator.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/Combinator.java @@ -2,151 +2,162 @@ import io.vavr.Tuple2; import io.vavr.control.Option; - import java.io.Serializable; -import java.util.*; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; import java.util.function.Supplier; import java.util.stream.Stream; public final class Combinator implements Serializable, Iterator>> { - private MatchedVariables variables; - private final Supplier>> allFacts; - private final List predicates; - private final Iterator> currentFacts; - private Combinator currentIt; - private final SymbolTable symbols; - - private Origin currentOrigin; - - private Option>> nextElement; - - @Override - public boolean hasNext() { - if (this.nextElement != null && this.nextElement.isDefined()) { - return true; - } + private MatchedVariables variables; + private final Supplier>> allFacts; + private final List predicates; + private final Iterator> currentFacts; + private Combinator currentIt; + private final SymbolTable symbolTable; + + private Origin currentOrigin; + + private Option>> nextElement; + + @Override + public boolean hasNext() { + if (this.nextElement != null && this.nextElement.isDefined()) { + return true; + } + this.nextElement = getNext(); + return this.nextElement.isDefined(); + } + + @Override + public Tuple2> next() { + if (this.nextElement == null || !this.nextElement.isDefined()) { this.nextElement = getNext(); - return this.nextElement.isDefined(); - } - - @Override - public Tuple2> next() { - if (this.nextElement == null || !this.nextElement.isDefined()) { - this.nextElement = getNext(); - } - if (this.nextElement == null || !this.nextElement.isDefined()) { - throw new NoSuchElementException(); + } + if (this.nextElement == null || !this.nextElement.isDefined()) { + throw new NoSuchElementException(); + } else { + Tuple2> t = this.nextElement.get(); + this.nextElement = Option.none(); + return t; + } + } + + public Option>> getNext() { + if (this.predicates.isEmpty()) { + final Option> vOpt = this.variables.complete(); + if (vOpt.isEmpty()) { + return Option.none(); } else { - Tuple2> t = this.nextElement.get(); - this.nextElement = Option.none(); - return t; - } - } - - public Option>> getNext() { - if (this.predicates.isEmpty()) { - final Option> v_opt = this.variables.complete(); - if (v_opt.isEmpty()) { - return Option.none(); - } else { - Map variables = v_opt.get(); - // if there were no predicates, - // we should return a value, but only once. To prevent further - // successful calls, we create a set of variables that cannot - // possibly be completed, so the next call will fail - Set set = new HashSet<>(); - set.add((long) 0); - - this.variables = new MatchedVariables(set); - return Option.some(new Tuple2<>(new Origin(), variables)); - } + Map variables = vOpt.get(); + // if there were no predicates, + // we should return a value, but only once. To prevent further + // successful calls, we create a set of variables that cannot + // possibly be completed, so the next call will fail + Set set = new HashSet<>(); + set.add((long) 0); + + this.variables = new MatchedVariables(set); + return Option.some(new Tuple2<>(new Origin(), variables)); } - - while (true) { - if (this.currentIt == null) { - Predicate predicate = this.predicates.get(0); - - while (true) { - // we iterate over the facts that match the current predicate - if (this.currentFacts.hasNext()) { - final Tuple2 t = this.currentFacts.next(); - Origin currentOrigin = t._1.clone(); - Fact fact = t._2; - - // create a new MatchedVariables in which we fix variables we could unify from our first predicate and the current fact - MatchedVariables vars = this.variables.clone(); - boolean matchTerms = true; - - // we know the fact matches the predicate's format so they have the same number of terms - // fill the MatchedVariables before creating the next combinator - for (int i = 0; i < predicate.terms().size(); ++i) { - final Term term = predicate.terms().get(i); - if (term instanceof Term.Variable) { - final long key = ((Term.Variable) term).value(); - final Term value = fact.predicate().terms().get(i); - - if (!vars.insert(key, value)) { - matchTerms = false; - } - if (!matchTerms) { - break; - } - } - } - - // the fact did not match the predicate, try the next one - if (!matchTerms) { - continue; - } - - // there are no more predicates to check - if (this.predicates.size() == 1) { - final Option> v_opt = vars.complete(); - if (v_opt.isEmpty()) { - continue; - } else { - return Option.some(new Tuple2<>(currentOrigin, v_opt.get())); - } - } else { - this.currentOrigin = currentOrigin; - // we found a matching fact, we create a new combinator over the rest of the predicates - // no need to copy all the expressions at all levels - this.currentIt = new Combinator(vars, predicates.subList(1, predicates.size()), this.allFacts, this.symbols); - } + } + + while (true) { + if (this.currentIt == null) { + Predicate predicate = this.predicates.get(0); + + while (true) { + // we iterate over the facts that match the current predicate + if (this.currentFacts.hasNext()) { + final Tuple2 t = this.currentFacts.next(); + Origin currentOrigin = t._1.clone(); + Fact fact = t._2; + + // create a new MatchedVariables in which we fix variables we could unify from our first + // predicate and the current fact + MatchedVariables vars = this.variables.clone(); + boolean matchTerms = true; + + // we know the fact matches the predicate's format so they have the same number of terms + // fill the MatchedVariables before creating the next combinator + for (int i = 0; i < predicate.terms().size(); ++i) { + final Term term = predicate.terms().get(i); + if (term instanceof Term.Variable) { + final long key = ((Term.Variable) term).value(); + final Term value = fact.predicate().terms().get(i); + + if (!vars.insert(key, value)) { + matchTerms = false; + } + if (!matchTerms) { break; + } + } + } + + // the fact did not match the predicate, try the next one + if (!matchTerms) { + continue; + } - } else { - return Option.none(); - } + // there are no more predicates to check + if (this.predicates.size() == 1) { + final Option> vOpt = vars.complete(); + if (vOpt.isEmpty()) { + continue; + } else { + return Option.some(new Tuple2<>(currentOrigin, vOpt.get())); + } + } else { + this.currentOrigin = currentOrigin; + // we found a matching fact, we create a new combinator over the rest of the + // predicates + // no need to copy all the expressions at all levels + this.currentIt = + new Combinator( + vars, predicates.subList(1, predicates.size()), this.allFacts, this.symbolTable); } - } + break; - if (this.currentIt == null) { + } else { return Option.none(); - } + } + } + } + + if (this.currentIt == null) { + return Option.none(); + } - Option>> opt = this.currentIt.getNext(); + Option>> opt = this.currentIt.getNext(); - if (opt.isDefined()) { - Tuple2> t = opt.get(); - return Option.some(new Tuple2<>(t._1.union(currentOrigin), t._2)); - } else { - currentOrigin = null; - currentIt = null; - } + if (opt.isDefined()) { + Tuple2> t = opt.get(); + return Option.some(new Tuple2<>(t._1.union(currentOrigin), t._2)); + } else { + currentOrigin = null; + currentIt = null; } - } - - - public Combinator(final MatchedVariables variables, final List predicates, - Supplier>> all_facts, final SymbolTable symbols) { - this.variables = variables; - this.allFacts = all_facts; - this.currentIt = null; - this.predicates = predicates; - this.currentFacts = all_facts.get().filter((tuple) -> tuple._2.match_predicate(predicates.get(0))).iterator(); - this.symbols = symbols; - this.currentOrigin = null; - this.nextElement = null; - } + } + } + + public Combinator( + final MatchedVariables variables, + final List predicates, + Supplier>> allFacts, + final SymbolTable symbolTable) { + this.variables = variables; + this.allFacts = allFacts; + this.currentIt = null; + this.predicates = predicates; + this.currentFacts = + allFacts.get().filter((tuple) -> tuple._2.matchPredicate(predicates.get(0))).iterator(); + this.symbolTable = symbolTable; + this.currentOrigin = null; + this.nextElement = null; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/Fact.java b/src/main/java/org/biscuitsec/biscuit/datalog/Fact.java index 4bc73d64..bb6c105d 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/Fact.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/Fact.java @@ -1,66 +1,67 @@ package org.biscuitsec.biscuit.datalog; +import static io.vavr.API.Left; +import static io.vavr.API.Right; + import biscuit.format.schema.Schema; -import org.biscuitsec.biscuit.error.Error; import io.vavr.control.Either; - import java.io.Serializable; import java.util.List; import java.util.Objects; - -import static io.vavr.API.Left; -import static io.vavr.API.Right; +import org.biscuitsec.biscuit.error.Error; public final class Fact implements Serializable { - private final Predicate predicate; + private final Predicate predicate; - public final Predicate predicate() { - return this.predicate; - } + public Predicate predicate() { + return this.predicate; + } - public boolean match_predicate(final Predicate rule_predicate) { - return this.predicate.match(rule_predicate); - } + public boolean matchPredicate(final Predicate rulePredicate) { + return this.predicate.match(rulePredicate); + } - public Fact(final Predicate predicate) { - this.predicate = predicate; - } + public Fact(final Predicate predicate) { + this.predicate = predicate; + } - public Fact(final long name, final List terms){ - this.predicate = new Predicate(name, terms); - } + public Fact(final long name, final List terms) { + this.predicate = new Predicate(name, terms); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Fact fact = (Fact) o; - return Objects.equals(predicate, fact.predicate); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Fact fact = (Fact) o; + return Objects.equals(predicate, fact.predicate); + } - @Override - public int hashCode() { - return Objects.hash(predicate); - } + @Override + public int hashCode() { + return Objects.hash(predicate); + } - @Override - public String toString() { - return this.predicate.toString(); - } + @Override + public String toString() { + return this.predicate.toString(); + } - public Schema.FactV2 serialize() { - return Schema.FactV2.newBuilder() - .setPredicate(this.predicate.serialize()) - .build(); - } + public Schema.FactV2 serialize() { + return Schema.FactV2.newBuilder().setPredicate(this.predicate.serialize()).build(); + } - static public Either deserializeV2(Schema.FactV2 fact) { - Either res = Predicate.deserializeV2(fact.getPredicate()); - if(res.isLeft()) { - Error.FormatError e = res.getLeft(); - return Left(e); - } else { - return Right(new Fact(res.get())); - } - } + public static Either deserializeV2(Schema.FactV2 fact) { + Either res = Predicate.deserializeV2(fact.getPredicate()); + if (res.isLeft()) { + Error.FormatError e = res.getLeft(); + return Left(e); + } else { + return Right(new Fact(res.get())); + } + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/FactSet.java b/src/main/java/org/biscuitsec/biscuit/datalog/FactSet.java index d1c4b26c..bc10e045 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/FactSet.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/FactSet.java @@ -1,110 +1,110 @@ package org.biscuitsec.biscuit.datalog; import io.vavr.Tuple2; - -import java.util.*; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; import java.util.stream.Stream; -public class FactSet { - private final HashMap> facts; +public final class FactSet { + private final HashMap> facts; - public FactSet() { - facts = new HashMap<>(); - } + public FactSet() { + facts = new HashMap<>(); + } - public FactSet(Origin o, HashSet factSet) { - facts = new HashMap<>(); - facts.put(o, factSet); - } + public FactSet(Origin o, HashSet factSet) { + facts = new HashMap<>(); + facts.put(o, factSet); + } - public HashMap> facts() { - return this.facts; - } + public HashMap> facts() { + return this.facts; + } - public void add(Origin origin, Fact fact) { - if(!facts.containsKey(origin)) { - facts.put(origin, new HashSet<>()); - } - facts.get(origin).add(fact); + public void add(Origin origin, Fact fact) { + if (!facts.containsKey(origin)) { + facts.put(origin, new HashSet<>()); } + facts.get(origin).add(fact); + } - public int size() { - int size = 0; - for(HashSet h: facts.values()) { - size += h.size(); - } - - return size; + public int size() { + int size = 0; + for (HashSet h : facts.values()) { + size += h.size(); } - public FactSet clone() { - FactSet newFacts = new FactSet(); + return size; + } - for(Map.Entry> entry: this.facts.entrySet()) { - HashSet h = new HashSet<>(entry.getValue()); - newFacts.facts.put(entry.getKey(), h); - } + public FactSet clone() { + FactSet newFacts = new FactSet(); - return newFacts; + for (Map.Entry> entry : this.facts.entrySet()) { + HashSet h = new HashSet<>(entry.getValue()); + newFacts.facts.put(entry.getKey(), h); } - public void merge(FactSet other) { - for(Map.Entry> entry: other.facts.entrySet()) { - if(!facts.containsKey(entry.getKey())) { - facts.put(entry.getKey(), entry.getValue()); - } else { - facts.get(entry.getKey()).addAll(entry.getValue()); - } - } - } - public Stream stream(TrustedOrigins blockIds) { - return facts.entrySet() - .stream() - .filter(entry -> { - Origin o = entry.getKey(); - return blockIds.contains(o); - }) - .flatMap(entry -> entry.getValue() - .stream() - .map(fact -> new Tuple2<>(entry.getKey(), fact))); - } + return newFacts; + } - public Stream stream() { - return facts.entrySet() - .stream() - .flatMap(entry -> entry.getValue() - .stream()); + public void merge(FactSet other) { + for (Map.Entry> entry : other.facts.entrySet()) { + if (!facts.containsKey(entry.getKey())) { + facts.put(entry.getKey(), entry.getValue()); + } else { + facts.get(entry.getKey()).addAll(entry.getValue()); + } } - - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - FactSet factSet = (FactSet) o; - - return facts.equals(factSet.facts); + } + + public Stream stream(TrustedOrigins blockIds) { + return facts.entrySet().stream() + .filter( + entry -> { + Origin o = entry.getKey(); + return blockIds.contains(o); + }) + .flatMap( + entry -> entry.getValue().stream().map(fact -> new Tuple2<>(entry.getKey(), fact))); + } + + public Stream stream() { + return facts.entrySet().stream().flatMap(entry -> entry.getValue().stream()); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return facts.hashCode(); + if (o == null || getClass() != o.getClass()) { + return false; } - @Override - public String toString() { - StringBuilder res = new StringBuilder("FactSet {"); - for(Map.Entry> entry: this.facts.entrySet()) { - res.append("\n\t").append(entry.getKey()).append("["); - for(Fact fact: entry.getValue()) { - res.append("\n\t\t").append(fact); - } - res.append("\n]"); - } - res.append("\n}"); - - return res.toString(); + FactSet factSet = (FactSet) o; + + return facts.equals(factSet.facts); + } + + @Override + public int hashCode() { + return facts.hashCode(); + } + + @Override + public String toString() { + StringBuilder res = new StringBuilder("FactSet {"); + for (Map.Entry> entry : this.facts.entrySet()) { + res.append("\n\t").append(entry.getKey()).append("["); + for (Fact fact : entry.getValue()) { + res.append("\n\t\t").append(fact); + } + res.append("\n]"); } -} + res.append("\n}"); + return res.toString(); + } +} diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/MatchedVariables.java b/src/main/java/org/biscuitsec/biscuit/datalog/MatchedVariables.java index 3ecaeb35..6f55d192 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/MatchedVariables.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/MatchedVariables.java @@ -1,86 +1,89 @@ package org.biscuitsec.biscuit.datalog; -import org.biscuitsec.biscuit.datalog.expressions.Expression; import io.vavr.control.Option; -import org.biscuitsec.biscuit.error.Error; - import java.io.Serializable; -import java.util.*; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import org.biscuitsec.biscuit.datalog.expressions.Expression; +import org.biscuitsec.biscuit.error.Error; public final class MatchedVariables implements Serializable { - private final Map> variables; + private final Map> variables; - public boolean insert(final long key, final Term value) { - if (this.variables.containsKey(key)) { - final Optional val = this.variables.get(key); - if (val.isPresent()) { - return val.get().equals(value); - } else { - this.variables.put(key, Optional.of(value)); - return true; - } + public boolean insert(final long key, final Term value) { + if (this.variables.containsKey(key)) { + final Optional val = this.variables.get(key); + if (val.isPresent()) { + return val.get().equals(value); } else { - return false; + this.variables.put(key, Optional.of(value)); + return true; } - } + } else { + return false; + } + } - public Optional get(final long key) { - return this.variables.get(key); - } + public Optional get(final long key) { + return this.variables.get(key); + } - public boolean is_complete() { - return this.variables.values().stream().allMatch((v) -> v.isPresent()); - } + public boolean isComplete() { + return this.variables.values().stream().allMatch((v) -> v.isPresent()); + } - public Option> complete() { - final Map variables = new HashMap<>(); - for (final Map.Entry> entry : this.variables.entrySet()) { - if (entry.getValue().isPresent()) { - variables.put(entry.getKey(), entry.getValue().get()); - } else { - return Option.none(); - } - } - return Option.some(variables); - } - - public MatchedVariables clone() { - final MatchedVariables other = new MatchedVariables(this.variables.keySet()); - for (final Map.Entry> entry : this.variables.entrySet()) { - if (entry.getValue().isPresent()) { - other.variables.put(entry.getKey(), entry.getValue()); - } + public Option> complete() { + final Map variables = new HashMap<>(); + for (final Map.Entry> entry : this.variables.entrySet()) { + if (entry.getValue().isPresent()) { + variables.put(entry.getKey(), entry.getValue().get()); + } else { + return Option.none(); } - return other; - } + } + return Option.some(variables); + } - public MatchedVariables(final Set ids) { - this.variables = new HashMap<>(); - for (final Long id : ids) { - this.variables.put(id, Optional.empty()); + public MatchedVariables clone() { + final MatchedVariables other = new MatchedVariables(this.variables.keySet()); + for (final Map.Entry> entry : this.variables.entrySet()) { + if (entry.getValue().isPresent()) { + other.variables.put(entry.getKey(), entry.getValue()); } - } - - public Option> check_expressions(List expressions, SymbolTable symbols) throws Error { - final Option> vars = this.complete(); - if (vars.isDefined()) { - Map variables = vars.get(); + } + return other; + } + public MatchedVariables(final Set ids) { + this.variables = new HashMap<>(); + for (final Long id : ids) { + this.variables.put(id, Optional.empty()); + } + } - for(Expression e: expressions) { - Term term = e.evaluate(variables, new TemporarySymbolTable(symbols)); + public Option> checkExpressions(List expressions, SymbolTable symbolTable) + throws Error { + final Option> vars = this.complete(); + if (vars.isDefined()) { + Map variables = vars.get(); - if(! (term instanceof Term.Bool)) { - throw new Error.InvalidType(); - } - if(!term.equals(new Term.Bool(true))) { - return Option.none(); - } - } + for (Expression e : expressions) { + Term term = e.evaluate(variables, new TemporarySymbolTable(symbolTable)); - return Option.some(variables); - } else { - return Option.none(); + if (!(term instanceof Term.Bool)) { + throw new Error.InvalidType(); + } + if (!term.equals(new Term.Bool(true))) { + return Option.none(); + } } - } + + return Option.some(variables); + } else { + return Option.none(); + } + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/Origin.java b/src/main/java/org/biscuitsec/biscuit/datalog/Origin.java index 5c49d117..31c19419 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/Origin.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/Origin.java @@ -1,68 +1,86 @@ package org.biscuitsec.biscuit.datalog; -import java.util.*; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.Objects; +import java.util.Set; -public class Origin { - public HashSet inner; +public final class Origin { + private final HashSet blockIds; - public Origin() { - inner = new HashSet<>(); - } + public Origin() { + this.blockIds = new HashSet<>(); + } - private Origin(HashSet inner) { - this.inner = inner; - } + public Origin(Long i) { + this.blockIds = new HashSet<>(); + this.blockIds.add(i); + } - public Origin(Long i) { - this.inner = new HashSet<>(); - this.inner.add(i); - } + public Origin(int i) { + this.blockIds = new HashSet<>(); + this.blockIds.add((long) i); + } - public Origin(int i) { - this.inner = new HashSet<>(); - this.inner.add((long)i); - } + public static Origin authorizer() { + return new Origin(Long.MAX_VALUE); + } - public static Origin authorizer() { - return new Origin(Long.MAX_VALUE); - } - public void add(int i) { - inner.add((long) i); - } - public void add(long i) { - inner.add(i); - } + public void add(int i) { + blockIds.add((long) i); + } - public Origin union(Origin other) { - Origin o = this.clone(); - o.inner.addAll(other.inner); - return o; - } + public void add(long i) { + blockIds.add(i); + } - public Origin clone() { - final HashSet newInner = new HashSet<>(this.inner); - return new Origin(newInner); - } + public boolean addAll(final Collection newBlockIds) { + return this.blockIds.addAll(newBlockIds); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + public Origin union(Origin other) { + Origin o = this.clone(); + o.blockIds.addAll(other.blockIds); + return o; + } - Origin origin = (Origin) o; + public boolean containsAll(Origin other) { + return this.blockIds.containsAll(other.blockIds); + } - return Objects.equals(inner, origin.inner); - } + @Override + protected Origin clone() { + final Origin newOrigin = new Origin(); + newOrigin.addAll(this.blockIds); + return newOrigin; + } - @Override - public int hashCode() { - return inner != null ? inner.hashCode() : 0; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public String toString() { - return "Origin{" + - "inner=" + inner + - '}'; + if (o == null || getClass() != o.getClass()) { + return false; } + + Origin origin = (Origin) o; + + return Objects.equals(blockIds, origin.blockIds); + } + + @Override + public int hashCode() { + return blockIds != null ? blockIds.hashCode() : 0; + } + + @Override + public String toString() { + return "Origin{inner=" + blockIds + '}'; + } + + public Set blockIds() { + return Collections.unmodifiableSet(this.blockIds); + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/Predicate.java b/src/main/java/org/biscuitsec/biscuit/datalog/Predicate.java index 1ff05650..5d774d75 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/Predicate.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/Predicate.java @@ -1,103 +1,111 @@ package org.biscuitsec.biscuit.datalog; +import static io.vavr.API.Left; +import static io.vavr.API.Right; + import biscuit.format.schema.Schema; -import org.biscuitsec.biscuit.error.Error; import io.vavr.control.Either; - import java.io.Serializable; import java.util.ArrayList; import java.util.List; import java.util.ListIterator; import java.util.Objects; import java.util.stream.Collectors; - -import static io.vavr.API.Left; -import static io.vavr.API.Right; +import org.biscuitsec.biscuit.error.Error; public final class Predicate implements Serializable { - private final long name; - private final List terms; - - public long name() { - return this.name; - } - - public final List terms() { - return this.terms; - } - - public final ListIterator ids_iterator() { - return this.terms.listIterator(); - } - - public boolean match(final Predicate rule_predicate) { - if (this.name != rule_predicate.name) { - return false; - } - if (this.terms.size() != rule_predicate.terms.size()) { - return false; - } - for (int i = 0; i < this.terms.size(); ++i) { - if (!this.terms.get(i).match(rule_predicate.terms.get(i))) { - return false; - } + private final long name; + private final List terms; + + public long name() { + return this.name; + } + + public List terms() { + return this.terms; + } + + public ListIterator idsIterator() { + return this.terms.listIterator(); + } + + public boolean match(final Predicate rulePredicate) { + if (this.name != rulePredicate.name) { + return false; + } + if (this.terms.size() != rulePredicate.terms.size()) { + return false; + } + for (int i = 0; i < this.terms.size(); ++i) { + if (!this.terms.get(i).match(rulePredicate.terms.get(i))) { + return false; } + } + return true; + } + + public Predicate clone() { + final List terms = new ArrayList<>(); + terms.addAll(this.terms); + return new Predicate(this.name, terms); + } + + public Predicate(final long name, final List terms) { + this.name = name; + this.terms = terms; + } + + @Override + public boolean equals(Object o) { + if (this == o) { return true; - } - - public Predicate clone() { - final List terms = new ArrayList<>(); - terms.addAll(this.terms); - return new Predicate(this.name, terms); - } - - public Predicate(final long name, final List terms) { - this.name = name; - this.terms = terms; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Predicate predicate = (Predicate) o; - return name == predicate.name && - Objects.equals(terms, predicate.terms); - } - - @Override - public int hashCode() { - return Objects.hash(name, terms); - } - - @Override - public String toString() { - return this.name + "(" + String.join(", ", this.terms.stream().map((i) -> (i == null) ? "(null)" : i.toString()).collect(Collectors.toList())) + ")"; - } - - public Schema.PredicateV2 serialize() { - Schema.PredicateV2.Builder builder = Schema.PredicateV2.newBuilder() - .setName(this.name); - - for (int i = 0; i < this.terms.size(); i++) { - builder.addTerms(this.terms.get(i).serialize()); - } - - return builder.build(); - } - - static public Either deserializeV2(Schema.PredicateV2 predicate) { - ArrayList terms = new ArrayList<>(); - for (Schema.TermV2 id: predicate.getTermsList()) { - Either res = Term.deserialize_enumV2(id); - if(res.isLeft()) { - Error.FormatError e = res.getLeft(); - return Left(e); - } else { - terms.add(res.get()); - } + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Predicate predicate = (Predicate) o; + return name == predicate.name && Objects.equals(terms, predicate.terms); + } + + @Override + public int hashCode() { + return Objects.hash(name, terms); + } + + @Override + public String toString() { + return this.name + + "(" + + String.join( + ", ", + this.terms.stream() + .map((i) -> (i == null) ? "(null)" : i.toString()) + .collect(Collectors.toList())) + + ")"; + } + + public Schema.PredicateV2 serialize() { + Schema.PredicateV2.Builder builder = Schema.PredicateV2.newBuilder().setName(this.name); + + for (int i = 0; i < this.terms.size(); i++) { + builder.addTerms(this.terms.get(i).serialize()); + } + + return builder.build(); + } + + public static Either deserializeV2(Schema.PredicateV2 predicate) { + ArrayList terms = new ArrayList<>(); + for (Schema.TermV2 id : predicate.getTermsList()) { + Either res = Term.deserializeEnumV2(id); + if (res.isLeft()) { + Error.FormatError e = res.getLeft(); + return Left(e); + } else { + terms.add(res.get()); } + } - return Right(new Predicate(predicate.getName(), terms)); - } + return Right(new Predicate(predicate.getName(), terms)); + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/Rule.java b/src/main/java/org/biscuitsec/biscuit/datalog/Rule.java index 1fa24b0a..6ae5e4ea 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/Rule.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/Rule.java @@ -1,277 +1,315 @@ package org.biscuitsec.biscuit.datalog; +import static io.vavr.API.Left; +import static io.vavr.API.Right; + import biscuit.format.schema.Schema; -import org.biscuitsec.biscuit.datalog.expressions.Expression; -import org.biscuitsec.biscuit.error.Error; import io.vavr.Tuple2; import io.vavr.Tuple3; import io.vavr.control.Either; - import java.io.Serializable; -import java.util.*; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.Spliterator; +import java.util.Spliterators; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; import java.util.stream.StreamSupport; - -import static io.vavr.API.Left; -import static io.vavr.API.Right; +import org.biscuitsec.biscuit.datalog.expressions.Expression; +import org.biscuitsec.biscuit.error.Error; public final class Rule implements Serializable { - private final Predicate head; - private final List body; - private final List expressions; - private final List scopes; - - public final Predicate head() { - return this.head; - } - - public final List body() { - return this.body; - } - - public final List expressions() { - return this.expressions; - } - - public List scopes() { - return scopes; - } - - public Stream>> apply( - final Supplier>> factsSupplier, Long ruleOrigin, SymbolTable symbols) { - MatchedVariables variables = variablesSet(); - - Combinator combinator = new Combinator(variables, this.body, factsSupplier, symbols); - Spliterator>> splitItr = Spliterators - .spliteratorUnknownSize(combinator, Spliterator.ORDERED); - Stream>> stream = StreamSupport.stream(splitItr, false); - - //somehow we have inference errors when writing this as a lambda - return stream.map(t -> { - Origin origin = t._1; - Map generatedVariables = t._2; - TemporarySymbolTable temporarySymbols = new TemporarySymbolTable(symbols); - for (Expression e : this.expressions) { - try { - Term term = e.evaluate(generatedVariables, temporarySymbols); - - if (term instanceof Term.Bool) { - Term.Bool b = (Term.Bool) term; - if (!b.value()) { - return Either.right(new Tuple3<>(origin, generatedVariables, false)); - } - // continue evaluating if true - } else { - return Either.left(new Error.InvalidType()); - } - } catch(Error error) { - return Either.left(error); + private final Predicate head; + private final List body; + private final List expressions; + private final List scopes; + + public Predicate head() { + return this.head; + } + + public List body() { + return this.body; + } + + public List expressions() { + return this.expressions; + } + + public List scopes() { + return scopes; + } + + public Stream>> apply( + final Supplier>> factsSupplier, + Long ruleOrigin, + SymbolTable symbolTable) { + MatchedVariables variables = variablesSet(); + + Combinator combinator = new Combinator(variables, this.body, factsSupplier, symbolTable); + Spliterator>> splitItr = + Spliterators.spliteratorUnknownSize(combinator, Spliterator.ORDERED); + Stream>> stream = StreamSupport.stream(splitItr, false); + + // somehow we have inference errors when writing this as a lambda + return stream + .map( + t -> { + Origin origin = t._1; + Map generatedVariables = t._2; + TemporarySymbolTable temporarySymbols = new TemporarySymbolTable(symbolTable); + for (Expression e : this.expressions) { + try { + Term term = e.evaluate(generatedVariables, temporarySymbols); + + if (term instanceof Term.Bool) { + Term.Bool b = (Term.Bool) term; + if (!b.value()) { + return Either.right(new Tuple3<>(origin, generatedVariables, false)); } - - } - return Either.right(new Tuple3<>(origin, generatedVariables, true)); - }) - // sometimes we need to make the compiler happy - .filter((java.util.function.Predicate>) - res -> res.isRight() & ((Tuple3, Boolean>) res.get())._3).map(res -> { - Tuple3, Boolean> t = (Tuple3, Boolean>) res.get(); - Origin origin = t._1; - Map generatedVariables = t._2; - - Predicate p = this.head.clone(); - for (int index = 0; index < p.terms().size(); index++) { - if (p.terms().get(index) instanceof Term.Variable) { - Term.Variable var = (Term.Variable) p.terms().get(index); - if (!generatedVariables.containsKey(var.value())) { - //throw new Error("variables that appear in the head should appear in the body as well"); - return Either.left(new Error.InternalError()); - } - p.terms().set(index, generatedVariables.get(var.value())); - } - } - - origin.add(ruleOrigin); - return Either.right(new Tuple2(origin, new Fact(p))); - }); - } - - private MatchedVariables variablesSet() { - final Set variables_set = new HashSet<>(); - - for (final Predicate pred : this.body) { - variables_set.addAll(pred.terms().stream().filter((id) -> id instanceof Term.Variable).map((id) -> ((Term.Variable) id).value()).collect(Collectors.toSet())); - } - return new MatchedVariables(variables_set); - } - - // do not produce new facts, only find one matching set of facts - public boolean find_match(final FactSet facts, Long origin, TrustedOrigins scope, SymbolTable symbols) throws Error { - MatchedVariables variables = variablesSet(); - - if(this.body.isEmpty()) { - return variables.check_expressions(this.expressions, symbols).isDefined(); - } - - Supplier>> factsSupplier = () -> facts.stream(scope); - Stream>> stream = this.apply(factsSupplier, origin, symbols); - - Iterator>> it = stream.iterator(); - - if(!it.hasNext()) { - return false; - } - - Either> next = it.next(); - if(next.isRight()) { - return true; - } else { - throw next.getLeft(); - } - } - - // verifies that the expressions return true for every matching set of facts - public boolean check_match_all(final FactSet facts, TrustedOrigins scope, SymbolTable symbols) throws Error { - MatchedVariables variables = variablesSet(); - - if(this.body.isEmpty()) { - return variables.check_expressions(this.expressions, symbols).isDefined(); - } - - Supplier>> factsSupplier = () -> facts.stream(scope); - Combinator combinator = new Combinator(variables, this.body, factsSupplier, symbols); - boolean found = false; - - for (Combinator it = combinator; it.hasNext(); ) { - Tuple2> t = it.next(); - Map generatedVariables = t._2; - found = true; - - TemporarySymbolTable temporarySymbols = new TemporarySymbolTable(symbols); - for (Expression e : this.expressions) { - - Term term = e.evaluate(generatedVariables, temporarySymbols); - if (term instanceof Term.Bool) { - Term.Bool b = (Term.Bool) term; - if (!b.value()) { - return false; - } - // continue evaluating if true - } else { - throw new Error.InvalidType(); + // continue evaluating if true + } else { + return Either.left(new Error.InvalidType()); + } + } catch (Error error) { + return Either.left(error); + } + } + return Either.right(new Tuple3<>(origin, generatedVariables, true)); + }) + // sometimes we need to make the compiler happy + .filter( + (java.util.function.Predicate>) + res -> res.isRight() & ((Tuple3, Boolean>) res.get())._3) + .map( + res -> { + Tuple3, Boolean> t = + (Tuple3, Boolean>) res.get(); + Origin origin = t._1; + Map generatedVariables = t._2; + + Predicate p = this.head.clone(); + for (int index = 0; index < p.terms().size(); index++) { + if (p.terms().get(index) instanceof Term.Variable) { + Term.Variable var = (Term.Variable) p.terms().get(index); + if (!generatedVariables.containsKey(var.value())) { + // throw new Error("variables that appear in the head should appear in the body + // as well"); + return Either.left(new Error.InternalError()); + } + p.terms().set(index, generatedVariables.get(var.value())); + } } - } - } - return found; - } - - public Rule(final Predicate head, final List body, final List expressions) { - this.head = head; - this.body = body; - this.expressions = expressions; - this.scopes = new ArrayList<>(); - } - - public Rule(final Predicate head, final List body, final List expressions, - final List scopes) { - this.head = head; - this.body = body; - this.expressions = expressions; - this.scopes = scopes; - } - - public Schema.RuleV2 serialize() { - Schema.RuleV2.Builder b = Schema.RuleV2.newBuilder() - .setHead(this.head.serialize()); - - for (int i = 0; i < this.body.size(); i++) { - b.addBody(this.body.get(i).serialize()); - } - - for (int i = 0; i < this.expressions.size(); i++) { - b.addExpressions(this.expressions.get(i).serialize()); - } - - for (Scope scope: this.scopes) { - b.addScope(scope.serialize()); - } - return b.build(); - } - - static public Either deserializeV2(Schema.RuleV2 rule) { - ArrayList body = new ArrayList<>(); - for (Schema.PredicateV2 predicate: rule.getBodyList()) { - Either res = Predicate.deserializeV2(predicate); - if(res.isLeft()) { - Error.FormatError e = res.getLeft(); - return Left(e); - } else { - body.add(res.get()); - } + origin.add(ruleOrigin); + return Either.right(new Tuple2(origin, new Fact(p))); + }); + } + + private MatchedVariables variablesSet() { + final Set variablesSet = new HashSet<>(); + + for (final Predicate pred : this.body) { + variablesSet.addAll( + pred.terms().stream() + .filter((id) -> id instanceof Term.Variable) + .map((id) -> ((Term.Variable) id).value()) + .collect(Collectors.toSet())); + } + return new MatchedVariables(variablesSet); + } + + // do not produce new facts, only find one matching set of facts + public boolean findMatch( + final FactSet facts, Long origin, TrustedOrigins scope, SymbolTable symbolTable) throws Error { + MatchedVariables variables = variablesSet(); + + if (this.body.isEmpty()) { + return variables.checkExpressions(this.expressions, symbolTable).isDefined(); + } + + Supplier>> factsSupplier = () -> facts.stream(scope); + Stream>> stream = this.apply(factsSupplier, origin, symbolTable); + + Iterator>> it = stream.iterator(); + + if (!it.hasNext()) { + return false; + } + + Either> next = it.next(); + if (next.isRight()) { + return true; + } else { + throw next.getLeft(); + } + } + + // verifies that the expressions return true for every matching set of facts + public boolean checkMatchAll(final FactSet facts, TrustedOrigins scope, SymbolTable symbolTable) + throws Error { + MatchedVariables variables = variablesSet(); + + if (this.body.isEmpty()) { + return variables.checkExpressions(this.expressions, symbolTable).isDefined(); + } + + Supplier>> factsSupplier = () -> facts.stream(scope); + Combinator combinator = new Combinator(variables, this.body, factsSupplier, symbolTable); + boolean found = false; + + for (Combinator it = combinator; it.hasNext(); ) { + Tuple2> t = it.next(); + Map generatedVariables = t._2; + found = true; + + TemporarySymbolTable temporarySymbols = new TemporarySymbolTable(symbolTable); + for (Expression e : this.expressions) { + + Term term = e.evaluate(generatedVariables, temporarySymbols); + if (term instanceof Term.Bool) { + Term.Bool b = (Term.Bool) term; + if (!b.value()) { + return false; + } + // continue evaluating if true + } else { + throw new Error.InvalidType(); + } } - - ArrayList expressions = new ArrayList<>(); - for (Schema.ExpressionV2 expression: rule.getExpressionsList()) { - Either res = Expression.deserializeV2(expression); - if(res.isLeft()) { - Error.FormatError e = res.getLeft(); - return Left(e); - } else { - expressions.add(res.get()); - } + } + return found; + } + + public Rule( + final Predicate head, final List body, final List expressions) { + this.head = head; + this.body = body; + this.expressions = expressions; + this.scopes = new ArrayList<>(); + } + + public Rule( + final Predicate head, + final List body, + final List expressions, + final List scopes) { + this.head = head; + this.body = body; + this.expressions = expressions; + this.scopes = scopes; + } + + public Schema.RuleV2 serialize() { + Schema.RuleV2.Builder b = Schema.RuleV2.newBuilder().setHead(this.head.serialize()); + + for (int i = 0; i < this.body.size(); i++) { + b.addBody(this.body.get(i).serialize()); + } + + for (int i = 0; i < this.expressions.size(); i++) { + b.addExpressions(this.expressions.get(i).serialize()); + } + + for (Scope scope : this.scopes) { + b.addScope(scope.serialize()); + } + + return b.build(); + } + + public static Either deserializeV2(Schema.RuleV2 rule) { + ArrayList body = new ArrayList<>(); + for (Schema.PredicateV2 predicate : rule.getBodyList()) { + Either res = Predicate.deserializeV2(predicate); + if (res.isLeft()) { + Error.FormatError e = res.getLeft(); + return Left(e); + } else { + body.add(res.get()); } - - ArrayList scopes = new ArrayList<>(); - for (Schema.Scope scope: rule.getScopeList()) { - Either res = Scope.deserialize(scope); - if(res.isLeft()) { - Error.FormatError e = res.getLeft(); - return Left(e); - } else { - scopes.add(res.get()); - } + } + + ArrayList expressions = new ArrayList<>(); + for (Schema.ExpressionV2 expression : rule.getExpressionsList()) { + Either res = Expression.deserializeV2(expression); + if (res.isLeft()) { + Error.FormatError e = res.getLeft(); + return Left(e); + } else { + expressions.add(res.get()); } - - Either res = Predicate.deserializeV2(rule.getHead()); - if(res.isLeft()) { - Error.FormatError e = res.getLeft(); - return Left(e); + } + + ArrayList scopes = new ArrayList<>(); + for (Schema.Scope scope : rule.getScopeList()) { + Either res = Scope.deserialize(scope); + if (res.isLeft()) { + Error.FormatError e = res.getLeft(); + return Left(e); } else { - return Right(new Rule(res.get(), body, expressions, scopes)); + scopes.add(res.get()); } - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - Rule rule = (Rule) o; - - if (!Objects.equals(head, rule.head)) return false; - if (!Objects.equals(body, rule.body)) return false; - if (!Objects.equals(expressions, rule.expressions)) return false; - return Objects.equals(scopes, rule.scopes); - } - - @Override - public int hashCode() { - int result = head != null ? head.hashCode() : 0; - result = 31 * result + (body != null ? body.hashCode() : 0); - result = 31 * result + (expressions != null ? expressions.hashCode() : 0); - result = 31 * result + (scopes != null ? scopes.hashCode() : 0); - return result; - } - - @Override - public String toString() { - return "Rule{" + - "head=" + head + - ", body=" + body + - ", expressions=" + expressions + - ", scopes=" + scopes + - '}'; - } + } + + Either res = Predicate.deserializeV2(rule.getHead()); + if (res.isLeft()) { + Error.FormatError e = res.getLeft(); + return Left(e); + } else { + return Right(new Rule(res.get(), body, expressions, scopes)); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + Rule rule = (Rule) o; + + if (!Objects.equals(head, rule.head)) { + return false; + } + if (!Objects.equals(body, rule.body)) { + return false; + } + if (!Objects.equals(expressions, rule.expressions)) { + return false; + } + return Objects.equals(scopes, rule.scopes); + } + + @Override + public int hashCode() { + int result = head != null ? head.hashCode() : 0; + result = 31 * result + (body != null ? body.hashCode() : 0); + result = 31 * result + (expressions != null ? expressions.hashCode() : 0); + result = 31 * result + (scopes != null ? scopes.hashCode() : 0); + return result; + } + + @Override + public String toString() { + return "Rule{" + + "head=" + + head + + ", body=" + + body + + ", expressions=" + + expressions + + ", scopes=" + + scopes + + '}'; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/RuleSet.java b/src/main/java/org/biscuitsec/biscuit/datalog/RuleSet.java index e7643359..b2430522 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/RuleSet.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/RuleSet.java @@ -1,45 +1,47 @@ package org.biscuitsec.biscuit.datalog; import io.vavr.Tuple2; - -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.stream.Stream; -public class RuleSet { - public final HashMap>> rules; +public final class RuleSet { + private final HashMap>> rules; - public RuleSet() { - rules = new HashMap<>(); - } + public RuleSet() { + rules = new HashMap<>(); + } - public void add(Long origin, TrustedOrigins scope, Rule rule) { - if (!rules.containsKey(scope)) { - rules.put(scope, List.of(new Tuple2<>(origin, rule))); - } else { - rules.get(scope).add(new Tuple2<>(origin, rule)); - } + public void add(Long origin, TrustedOrigins scope, Rule rule) { + if (!rules.containsKey(scope)) { + rules.put(scope, List.of(new Tuple2<>(origin, rule))); + } else { + rules.get(scope).add(new Tuple2<>(origin, rule)); } + } - public RuleSet clone() { - RuleSet newRules = new RuleSet(); + public RuleSet clone() { + RuleSet newRules = new RuleSet(); - for (Map.Entry>> entry : this.rules.entrySet()) { - List> l = new ArrayList<>(entry.getValue()); - newRules.rules.put(entry.getKey(), l); - } - - return newRules; + for (Map.Entry>> entry : this.rules.entrySet()) { + List> l = new ArrayList<>(entry.getValue()); + newRules.rules.put(entry.getKey(), l); } - public Stream stream() { - return rules.entrySet() - .stream() - .flatMap(entry -> entry.getValue() - .stream() - .map(t -> t._2)); - } + return newRules; + } - public void clear() { - rules.clear(); - } + public Stream stream() { + return rules.entrySet().stream().flatMap(entry -> entry.getValue().stream().map(t -> t._2)); + } + + public void clear() { + this.rules.clear(); + } + + public HashMap>> getRules() { + return this.rules; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/RunLimits.java b/src/main/java/org/biscuitsec/biscuit/datalog/RunLimits.java index 57a95833..a3b16573 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/RunLimits.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/RunLimits.java @@ -2,17 +2,28 @@ import java.time.Duration; -public class RunLimits { - public int maxFacts = 1000; - public int maxIterations = 100; - public Duration maxTime = Duration.ofMillis(5); +public final class RunLimits { + private int maxFacts = 1000; + private int maxIterations = 100; + private Duration maxTime = Duration.ofMillis(5); - public RunLimits() { - } + public RunLimits() {} - public RunLimits(int maxFacts, int maxIterations, Duration maxTime) { - this.maxFacts = maxFacts; - this.maxIterations = maxIterations; - this.maxTime = maxTime; - } + public RunLimits(int maxFacts, int maxIterations, Duration maxTime) { + this.maxFacts = maxFacts; + this.maxIterations = maxIterations; + this.maxTime = maxTime; + } + + public int getMaxFacts() { + return this.maxFacts; + } + + public int getMaxIterations() { + return this.maxIterations; + } + + public Duration getMaxTime() { + return this.maxTime; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/SchemaVersion.java b/src/main/java/org/biscuitsec/biscuit/datalog/SchemaVersion.java index 075fd29a..757c8b8b 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/SchemaVersion.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/SchemaVersion.java @@ -1,102 +1,103 @@ package org.biscuitsec.biscuit.datalog; +import static io.vavr.API.Left; +import static io.vavr.API.Right; + +import io.vavr.control.Either; +import java.util.List; import org.biscuitsec.biscuit.datalog.expressions.Expression; import org.biscuitsec.biscuit.datalog.expressions.Op; import org.biscuitsec.biscuit.error.Error; -import io.vavr.control.Either; import org.biscuitsec.biscuit.token.format.SerializedBiscuit; -import java.util.List; - -import static io.vavr.API.Left; -import static io.vavr.API.Right; - -public class SchemaVersion { - private boolean containsScopes; - private boolean containsCheckAll; - private boolean containsV4; +public final class SchemaVersion { + private boolean containsScopes; + private boolean containsCheckAll; + private boolean containsV4; - public SchemaVersion(List facts, List rules, List checks, List scopes) { - containsScopes = !scopes.isEmpty(); + public SchemaVersion(List facts, List rules, List checks, List scopes) { + containsScopes = !scopes.isEmpty(); - if (!containsScopes) { - for (Rule r : rules) { - if (!r.scopes().isEmpty()) { - containsScopes = true; - break; - } - } - } - if (!containsScopes) { - for (Check check : checks) { - for (Rule query : check.queries()) { - if (!query.scopes().isEmpty()) { - containsScopes = true; - break; - } - } - } + if (!containsScopes) { + for (Rule r : rules) { + if (!r.scopes().isEmpty()) { + containsScopes = true; + break; } - - containsCheckAll = false; - for (Check check : checks) { - if (check.kind() == Check.Kind.All) { - containsCheckAll = true; - break; - } + } + } + if (!containsScopes) { + for (Check check : checks) { + for (Rule query : check.queries()) { + if (!query.scopes().isEmpty()) { + containsScopes = true; + break; + } } + } + } - containsV4 = false; - for (Check check : checks) { - for (Rule query : check.queries()) { - if (containsV4Ops(query.expressions())) { - containsV4 = true; - break; - } - } - } + containsCheckAll = false; + for (Check check : checks) { + if (check.kind() == Check.Kind.ALL) { + containsCheckAll = true; + break; + } } - public int version() { - if (containsScopes || containsV4 || containsCheckAll) { - return 4; - } else { - return SerializedBiscuit.MIN_SCHEMA_VERSION; + containsV4 = false; + for (Check check : checks) { + for (Rule query : check.queries()) { + if (containsV4Ops(query.expressions())) { + containsV4 = true; + break; } + } } + } - public Either checkCompatibility(int version) { - if (version < 4) { - if (containsScopes) { - return Left(new Error.FormatError.DeserializationError("v3 blocks must not have scopes")); - } - if (containsV4) { - return Left(new Error.FormatError.DeserializationError("v3 blocks must not have v4 operators (bitwise operators or !=")); - } - if (containsCheckAll) { - return Left(new Error.FormatError.DeserializationError("v3 blocks must not use check all")); - } - } + public int version() { + if (containsScopes || containsV4 || containsCheckAll) { + return 4; + } else { + return SerializedBiscuit.MIN_SCHEMA_VERSION; + } + } - return Right(null); + public Either checkCompatibility(int version) { + if (version < 4) { + if (containsScopes) { + return Left(new Error.FormatError.DeserializationError("v3 blocks must not have scopes")); + } + if (containsV4) { + return Left( + new Error.FormatError.DeserializationError( + "v3 blocks must not have v4 operators (bitwise operators or !=")); + } + if (containsCheckAll) { + return Left(new Error.FormatError.DeserializationError("v3 blocks must not use check all")); + } } - public static boolean containsV4Ops(List expressions) { - for (Expression e : expressions) { - for (Op op : e.getOps()) { - if (op instanceof Op.Binary) { - Op.Binary b = (Op.Binary) op; - switch (b.getOp()) { - case BitwiseAnd: - case BitwiseOr: - case BitwiseXor: - case NotEqual: - return true; - } - } - } + return Right(null); + } + + public static boolean containsV4Ops(List expressions) { + for (Expression e : expressions) { + for (Op op : e.getOps()) { + if (op instanceof Op.Binary) { + Op.Binary b = (Op.Binary) op; + switch (b.getOp()) { + case BitwiseAnd: + case BitwiseOr: + case BitwiseXor: + case NotEqual: + return true; + default: + } } - return false; + } } - + return false; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/Scope.java b/src/main/java/org/biscuitsec/biscuit/datalog/Scope.java index 7aede2fc..aee6ea13 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/Scope.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/Scope.java @@ -1,104 +1,110 @@ package org.biscuitsec.biscuit.datalog; -import biscuit.format.schema.Schema; -import org.biscuitsec.biscuit.error.Error; -import io.vavr.control.Either; - - import static io.vavr.API.Left; import static io.vavr.API.Right; -public class Scope { - public enum Kind { - Authority, - Previous, - PublicKey - } - - Kind kind; - long publicKey; - - private Scope(Kind kind, long publicKey) { - this.kind = kind; - this.publicKey = publicKey; - } - - public static Scope authority() { - return new Scope(Kind.Authority, 0); - } - - public static Scope previous() { - return new Scope(Kind.Previous, 0); - } +import biscuit.format.schema.Schema; +import io.vavr.control.Either; +import org.biscuitsec.biscuit.error.Error; - public static Scope publicKey(long publicKey) { - return new Scope(Kind.PublicKey, publicKey); +public final class Scope { + public enum Kind { + Authority, + Previous, + PublicKey + } + + private Kind kind; + private long publicKey; + + private Scope(Kind kind, long publicKey) { + this.kind = kind; + this.publicKey = publicKey; + } + + public static Scope authority() { + return new Scope(Kind.Authority, 0); + } + + public static Scope previous() { + return new Scope(Kind.Previous, 0); + } + + public static Scope publicKey(long publicKey) { + return new Scope(Kind.PublicKey, publicKey); + } + + public Kind kind() { + return kind; + } + + public long getPublicKey() { + return publicKey; + } + + public Schema.Scope serialize() { + Schema.Scope.Builder b = Schema.Scope.newBuilder(); + + switch (this.kind) { + case Authority: + b.setScopeType(Schema.Scope.ScopeType.Authority); + break; + case Previous: + b.setScopeType(Schema.Scope.ScopeType.Previous); + break; + case PublicKey: + b.setPublicKey(this.publicKey); + break; + default: } - public Kind kind() { - return kind; - } + return b.build(); + } - public long publicKey() { - return publicKey; + public static Either deserialize(Schema.Scope scope) { + if (scope.hasPublicKey()) { + long publicKey = scope.getPublicKey(); + return Right(Scope.publicKey(publicKey)); } - - public Schema.Scope serialize() { - Schema.Scope.Builder b = Schema.Scope.newBuilder(); - - switch (this.kind) { - case Authority: - b.setScopeType(Schema.Scope.ScopeType.Authority); - break; - case Previous: - b.setScopeType(Schema.Scope.ScopeType.Previous); - break; - case PublicKey: - b.setPublicKey(this.publicKey); - } - - return b.build(); + if (scope.hasScopeType()) { + switch (scope.getScopeType()) { + case Authority: + return Right(Scope.authority()); + case Previous: + return Right(Scope.previous()); + default: + return Left(new Error.FormatError.DeserializationError("invalid Scope")); + } } - - static public Either deserialize(Schema.Scope scope) { - if (scope.hasPublicKey()) { - long publicKey = scope.getPublicKey(); - return Right(Scope.publicKey(publicKey)); - } - if (scope.hasScopeType()) { - switch (scope.getScopeType()) { - case Authority: - return Right(Scope.authority()); - case Previous: - return Right(Scope.previous()); - } - } - return Left(new Error.FormatError.DeserializationError("invalid Scope")); + return Left(new Error.FormatError.DeserializationError("invalid Scope")); + } + + @Override + public String toString() { + return "Scope{" + "kind=" + kind + ", publicKey=" + publicKey + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public String toString() { - return "Scope{" + - "kind=" + kind + - ", publicKey=" + publicKey + - '}'; + if (o == null || getClass() != o.getClass()) { + return false; } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - Scope scope = (Scope) o; - - if (publicKey != scope.publicKey) return false; - return kind == scope.kind; - } + Scope scope = (Scope) o; - @Override - public int hashCode() { - int result = kind != null ? kind.hashCode() : 0; - result = 31 * result + (int) (publicKey ^ (publicKey >>> 32)); - return result; + if (publicKey != scope.publicKey) { + return false; } + return kind == scope.kind; + } + + @Override + public int hashCode() { + int result = kind != null ? kind.hashCode() : 0; + result = 31 * result + (int) (publicKey ^ (publicKey >>> 32)); + return result; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/SymbolTable.java b/src/main/java/org/biscuitsec/biscuit/datalog/SymbolTable.java index d5ead0a9..47486a7a 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/SymbolTable.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/SymbolTable.java @@ -1,310 +1,350 @@ package org.biscuitsec.biscuit.datalog; -import org.biscuitsec.biscuit.crypto.PublicKey; -import org.biscuitsec.biscuit.datalog.expressions.Expression; -import org.biscuitsec.biscuit.token.builder.Utils; import io.vavr.control.Option; - import java.io.Serializable; import java.time.Instant; import java.time.ZoneOffset; import java.time.format.DateTimeFormatter; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; +import org.biscuitsec.biscuit.crypto.PublicKey; +import org.biscuitsec.biscuit.datalog.expressions.Expression; +import org.biscuitsec.biscuit.token.builder.Utils; public final class SymbolTable implements Serializable { - public final static short DEFAULT_SYMBOLS_OFFSET = 1024; - - private final DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ISO_INSTANT; - - public String fromEpochIsoDate(long epochSec) { - return Instant.ofEpochSecond(epochSec).atOffset(ZoneOffset.ofTotalSeconds(0)).format(dateTimeFormatter); + public static final short DEFAULT_SYMBOLS_OFFSET = 1024; + + private final DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ISO_INSTANT; + + public String fromEpochIsoDate(long epochSec) { + return Instant.ofEpochSecond(epochSec) + .atOffset(ZoneOffset.ofTotalSeconds(0)) + .format(dateTimeFormatter); + } + + /** + * According to the + * specification, We need two symbols tables: * one for the defaults symbols indexed from 0 et + * 1023 in defaultSymbols list * one for the usages symbols indexed from 1024 in + * symbols list + */ + public static final List DEFAULT_SYMBOLS = + List.of( + "read", + "write", + "resource", + "operation", + "right", + "time", + "role", + "owner", + "tenant", + "namespace", + "user", + "team", + "service", + "admin", + "email", + "group", + "member", + "ip_address", + "client", + "client_ip", + "domain", + "path", + "version", + "cluster", + "node", + "hostname", + "nonce", + "query"); + + private final List symbols; + private final List publicKeys; + + public long insert(final String symbol) { + int index = this.DEFAULT_SYMBOLS.indexOf(symbol); + if (index == -1) { + index = this.symbols.indexOf(symbol); + if (index == -1) { + this.symbols.add(symbol); + return this.symbols.size() - 1 + DEFAULT_SYMBOLS_OFFSET; + } else { + return index + DEFAULT_SYMBOLS_OFFSET; + } + } else { + return index; } - - /** - * According to the specification, - * We need two symbols tables: - * * one for the defaults symbols indexed from 0 et 1023 in defaultSymbols list - * * one for the usages symbols indexed from 1024 in symbols list - */ - public final static List defaultSymbols = List.of( - "read", - "write", - "resource", - "operation", - "right", - "time", - "role", - "owner", - "tenant", - "namespace", - "user", - "team", - "service", - "admin", - "email", - "group", - "member", - "ip_address", - "client", - "client_ip", - "domain", - "path", - "version", - "cluster", - "node", - "hostname", - "nonce", - "query" - ); - public final List symbols; - private final List publicKeys; - - public long insert(final String symbol) { - int index = this.defaultSymbols.indexOf(symbol); - if (index == -1) { - index = this.symbols.indexOf(symbol); - if (index == -1) { - this.symbols.add(symbol); - return this.symbols.size() - 1 + DEFAULT_SYMBOLS_OFFSET; - } else { - return index + DEFAULT_SYMBOLS_OFFSET; - } - } else { - return index; - } - } - - public int currentOffset() { - return this.symbols.size(); - } - public int currentPublicKeyOffset() { - return this.publicKeys.size(); - } - - public List publicKeys() { - return publicKeys; - } - - public long insert(final PublicKey publicKey) { - int index = this.publicKeys.indexOf(publicKey); - if (index == -1) { - this.publicKeys.add(publicKey); - return this.publicKeys.size() - 1; - } else { - return index; - } - } - - public Term add(final String symbol) { - return new Term.Str(this.insert(symbol)); - } - - public Option get(final String symbol) { - // looking for symbol in default symbols - long index = this.defaultSymbols.indexOf(symbol); - if (index == -1) { - // looking for symbol in usages defined symbols - index = this.symbols.indexOf(symbol); - if (index == -1) { - return Option.none(); - } else { - return Option.some(index + DEFAULT_SYMBOLS_OFFSET); - } - } else { - return Option.some(index); - } - } - - public Option get_s(int i) { - if (i >= 0 && i < this.defaultSymbols.size() && i < DEFAULT_SYMBOLS_OFFSET) { - return Option.some(this.defaultSymbols.get(i)); - } else if (i >= DEFAULT_SYMBOLS_OFFSET && i < this.symbols.size() + DEFAULT_SYMBOLS_OFFSET) { - return Option.some(this.symbols.get(i - DEFAULT_SYMBOLS_OFFSET)); - } else { - return Option.none(); - } - } - - public Option get_pk(int i) { - if (i >= 0 && i < this.publicKeys.size()) { - return Option.some(this.publicKeys.get(i)); - } else { - return Option.none(); - } + } + + public long insert(final PublicKey publicKey) { + int index = this.publicKeys.indexOf(publicKey); + if (index == -1) { + this.publicKeys.add(publicKey); + return this.publicKeys.size() - 1; + } else { + return index; } - - public String print_rule(final Rule r) { - String res = this.print_predicate(r.head()); - res += " <- " + this.print_rule_body(r); - - return res; + } + + public int currentOffset() { + return this.symbols.size(); + } + + public int currentPublicKeyOffset() { + return this.publicKeys.size(); + } + + public List getPublicKeys() { + return publicKeys; + } + + public Term add(final String symbol) { + return new Term.Str(this.insert(symbol)); + } + + public Option get(final String symbol) { + // looking for symbol in default symbols + long index = this.DEFAULT_SYMBOLS.indexOf(symbol); + if (index == -1) { + // looking for symbol in usages defined symbols + index = this.symbols.indexOf(symbol); + if (index == -1) { + return Option.none(); + } else { + return Option.some(index + DEFAULT_SYMBOLS_OFFSET); + } + } else { + return Option.some(index); } - - public String print_rule_body(final Rule r) { - final List preds = r.body().stream().map((p) -> this.print_predicate(p)).collect(Collectors.toList()); - final List expressions = r.expressions().stream().map((c) -> this.print_expression(c)).collect(Collectors.toList()); - - String res = String.join(", ", preds); - if (!expressions.isEmpty()) { - if (!preds.isEmpty()) { - res += ", "; - } - res += String.join(", ", expressions); - } - - if(!r.scopes().isEmpty()) { - res += " trusting "; - final List scopes = r.scopes().stream().map((s) -> this.print_scope(s)).collect(Collectors.toList()); - res += String.join(", ", scopes); - } - return res; + } + + public Option getSymbol(int i) { + if (i >= 0 && i < this.DEFAULT_SYMBOLS.size() && i < DEFAULT_SYMBOLS_OFFSET) { + return Option.some(this.DEFAULT_SYMBOLS.get(i)); + } else if (i >= DEFAULT_SYMBOLS_OFFSET && i < this.symbols.size() + DEFAULT_SYMBOLS_OFFSET) { + return Option.some(this.symbols.get(i - DEFAULT_SYMBOLS_OFFSET)); + } else { + return Option.none(); } + } - public String print_expression(final Expression e) { - return e.print(this).get(); + public Option getPublicKey(int i) { + if (i >= 0 && i < this.publicKeys.size()) { + return Option.some(this.publicKeys.get(i)); + } else { + return Option.none(); } - - public String print_scope(final Scope scope) { - switch(scope.kind) { - case Authority: - return "authority"; - case Previous: - return "previous"; - case PublicKey: - Option pk = this.get_pk((int) scope.publicKey); - if(pk.isDefined()) { - return pk.get().toString(); - } - } - return "<"+ scope.publicKey+"?>"; + } + + public String formatRule(final Rule r) { + String res = this.formatPredicate(r.head()); + res += " <- " + this.formatRuleBody(r); + + return res; + } + + public String formatRuleBody(final Rule r) { + final List preds = + r.body().stream().map((p) -> this.formatPredicate(p)).collect(Collectors.toList()); + final List expressions = + r.expressions().stream().map((c) -> this.formatExpression(c)).collect(Collectors.toList()); + + String res = String.join(", ", preds); + if (!expressions.isEmpty()) { + if (!preds.isEmpty()) { + res += ", "; + } + res += String.join(", ", expressions); } - public String print_predicate(final Predicate p) { - List ids = p.terms().stream().map((t) -> { - return this.print_term(t); - }).collect(Collectors.toList()); - return Optional.ofNullable(this.print_symbol((int) p.name())).orElse("") + "(" + String.join(", ", ids) + ")"; + if (!r.scopes().isEmpty()) { + res += " trusting "; + final List scopes = + r.scopes().stream().map((s) -> this.formatScope(s)).collect(Collectors.toList()); + res += String.join(", ", scopes); } - - public String print_term(final Term i) { - if (i instanceof Term.Variable) { - return "$" + this.print_symbol((int) ((Term.Variable) i).value()); - } else if(i instanceof Term.Bool) { - return i.toString(); - } else if (i instanceof Term.Date) { - return fromEpochIsoDate(((Term.Date) i).value()); - } else if (i instanceof Term.Integer) { - return "" + ((Term.Integer) i).value(); - } else if (i instanceof Term.Str) { - return "\"" + this.print_symbol((int) ((Term.Str) i).value()) + "\""; - } else if (i instanceof Term.Bytes) { - return "hex:" + Utils.byteArrayToHexString(((Term.Bytes) i).value()).toLowerCase(); - } else if (i instanceof Term.Set) { - final List values = ((Term.Set) i).value().stream().map((v) -> this.print_term(v)).collect(Collectors.toList()); - return "[" + String.join(", ", values) + "]"; + return res; + } + + public String formatExpression(final Expression e) { + return e.print(this).get(); + } + + public String formatScope(final Scope scope) { + switch (scope.kind()) { + case Authority: + return "authority"; + case Previous: + return "previous"; + case PublicKey: + Option pk = this.getPublicKey((int) scope.getPublicKey()); + if (pk.isDefined()) { + return pk.get().toString(); } else { - return "???"; + return "<" + scope.getPublicKey() + "?>"; } + default: + return "<" + scope.getPublicKey() + "?>"; } - - public String print_fact(final Fact f) { - return this.print_predicate(f.predicate()); + } + + public String formatPredicate(final Predicate p) { + List ids = + p.terms().stream() + .map( + (t) -> { + return this.formatTerm(t); + }) + .collect(Collectors.toList()); + return Optional.ofNullable(this.formatSymbol((int) p.name())).orElse("") + + "(" + + String.join(", ", ids) + + ")"; + } + + public String formatTerm(final Term i) { + if (i instanceof Term.Variable) { + return "$" + this.formatSymbol((int) ((Term.Variable) i).value()); + } else if (i instanceof Term.Bool) { + return i.toString(); + } else if (i instanceof Term.Date) { + return fromEpochIsoDate(((Term.Date) i).value()); + } else if (i instanceof Term.Integer) { + return "" + ((Term.Integer) i).value(); + } else if (i instanceof Term.Str) { + return "\"" + this.formatSymbol((int) ((Term.Str) i).value()) + "\""; + } else if (i instanceof Term.Bytes) { + return "hex:" + Utils.byteArrayToHexString(((Term.Bytes) i).value()).toLowerCase(); + } else if (i instanceof Term.Set) { + final List values = + ((Term.Set) i) + .value().stream().map((v) -> this.formatTerm(v)).collect(Collectors.toList()); + return "[" + String.join(", ", values) + "]"; + } else { + return "???"; } - - public String print_check(final Check c) { - String prefix; - switch (c.kind()) { - case One: - prefix = "check if "; - break; - case All: - prefix = "check all "; - break; - default: - prefix = "check if "; - break; - } - final List queries = c.queries().stream().map((q) -> this.print_rule_body(q)).collect(Collectors.toList()); - return prefix + String.join(" or ", queries); + } + + public String formatFact(final Fact f) { + return this.formatPredicate(f.predicate()); + } + + public String formatCheck(final Check c) { + String prefix; + switch (c.kind()) { + case ONE: + prefix = "check if "; + break; + case ALL: + prefix = "check all "; + break; + default: + prefix = "check if "; + break; } - - public String print_world(final World w) { - final List facts = w.facts().stream().map((f) -> this.print_fact(f)).collect(Collectors.toList()); - final List rules = w.rules().stream().map((r) -> this.print_rule(r)).collect(Collectors.toList()); - - StringBuilder b = new StringBuilder(); - b.append("World {\n\tfacts: [\n\t\t"); - b.append(String.join(",\n\t\t", facts)); - b.append("\n\t],\n\trules: [\n\t\t"); - b.append(String.join(",\n\t\t", rules)); - b.append("\n\t]\n}"); - - return b.toString(); - } - - public String print_symbol(int i) { - return get_s(i).getOrElse("<" + i + "?>"); - } - - public SymbolTable() { - this.symbols = new ArrayList<>(); - this.publicKeys = new ArrayList<>(); - } - - public SymbolTable(SymbolTable s) { - this.symbols = new ArrayList<>(); - symbols.addAll(s.symbols); - this.publicKeys = new ArrayList<>(); - publicKeys.addAll(s.publicKeys); - } - - public SymbolTable(List symbols) { - this.symbols = new ArrayList<>(symbols); - this.publicKeys = new ArrayList<>(); - } - - public SymbolTable(List symbols, List publicKeys) { - this.symbols = new ArrayList<>(); - this.symbols.addAll(symbols); - this.publicKeys = new ArrayList<>(); - this.publicKeys.addAll(publicKeys); + final List queries = + c.queries().stream().map((q) -> this.formatRuleBody(q)).collect(Collectors.toList()); + return prefix + String.join(" or ", queries); + } + + public String formatWorld(final World w) { + final List facts = + w.getFacts().stream().map((f) -> this.formatFact(f)).collect(Collectors.toList()); + final List rules = + w.getRules().stream().map((r) -> this.formatRule(r)).collect(Collectors.toList()); + + StringBuilder b = new StringBuilder(); + b.append("World {\n\tfacts: [\n\t\t"); + b.append(String.join(",\n\t\t", facts)); + b.append("\n\t],\n\trules: [\n\t\t"); + b.append(String.join(",\n\t\t", rules)); + b.append("\n\t]\n}"); + + return b.toString(); + } + + public String formatSymbol(int i) { + return getSymbol(i).getOrElse("<" + i + "?>"); + } + + public SymbolTable() { + this.symbols = new ArrayList<>(); + this.publicKeys = new ArrayList<>(); + } + + public SymbolTable(SymbolTable s) { + this.symbols = new ArrayList<>(); + symbols.addAll(s.symbols); + this.publicKeys = new ArrayList<>(); + publicKeys.addAll(s.publicKeys); + } + + public SymbolTable(List symbols) { + this.symbols = new ArrayList<>(symbols); + this.publicKeys = new ArrayList<>(); + } + + public SymbolTable(SymbolTable sourceSymbolTable, List publicKeys) { + this(sourceSymbolTable.symbols, publicKeys); + } + + public SymbolTable(List symbols, List publicKeys) { + this.symbols = new ArrayList<>(); + this.symbols.addAll(symbols); + this.publicKeys = new ArrayList<>(); + this.publicKeys.addAll(publicKeys); + } + + public List getAllSymbols() { + ArrayList allSymbols = new ArrayList<>(); + allSymbols.addAll(DEFAULT_SYMBOLS); + allSymbols.addAll(symbols); + return allSymbols; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - public List getAllSymbols() { - ArrayList allSymbols = new ArrayList<>(); - allSymbols.addAll(defaultSymbols); - allSymbols.addAll(symbols); - return allSymbols; + if (o == null || getClass() != o.getClass()) { + return false; } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + SymbolTable that = (SymbolTable) o; - SymbolTable that = (SymbolTable) o; - - if (!dateTimeFormatter.equals(that.dateTimeFormatter)) return false; - if (!symbols.equals(that.symbols)) return false; - return publicKeys.equals(that.publicKeys); - } - - @Override - public int hashCode() { - int result = dateTimeFormatter.hashCode(); - result = 31 * result + symbols.hashCode(); - result = 31 * result + publicKeys.hashCode(); - return result; + if (!dateTimeFormatter.equals(that.dateTimeFormatter)) { + return false; } - - @Override - public String toString() { - return "SymbolTable{" + - "symbols=" + symbols + - ", publicKeys=" + publicKeys + - '}'; + if (!symbols.equals(that.symbols)) { + return false; } + return publicKeys.equals(that.publicKeys); + } + + @Override + public int hashCode() { + int result = dateTimeFormatter.hashCode(); + result = 31 * result + symbols.hashCode(); + result = 31 * result + publicKeys.hashCode(); + return result; + } + + @Override + public String toString() { + return "SymbolTable{" + "symbols=" + symbols + ", publicKeys=" + publicKeys + '}'; + } + + public List symbols() { + return Collections.unmodifiableList(symbols); + } + + public boolean disjoint(final SymbolTable other) { + return Collections.disjoint(this.symbols, other.symbols); + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/TemporarySymbolTable.java b/src/main/java/org/biscuitsec/biscuit/datalog/TemporarySymbolTable.java index c64b5a26..14f191b5 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/TemporarySymbolTable.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/TemporarySymbolTable.java @@ -1,46 +1,45 @@ package org.biscuitsec.biscuit.datalog; -import io.vavr.control.Option; +import static org.biscuitsec.biscuit.datalog.SymbolTable.DEFAULT_SYMBOLS_OFFSET; +import io.vavr.control.Option; import java.util.ArrayList; import java.util.List; -import static org.biscuitsec.biscuit.datalog.SymbolTable.DEFAULT_SYMBOLS_OFFSET; - -public class TemporarySymbolTable { - SymbolTable base; - int offset; - List symbols; - - public TemporarySymbolTable(SymbolTable base) { - this.offset = DEFAULT_SYMBOLS_OFFSET + base.currentOffset(); - this.base = base; - this.symbols = new ArrayList<>(); +public final class TemporarySymbolTable { + private SymbolTable base; + private int offset; + private List symbols; + + public TemporarySymbolTable(SymbolTable base) { + this.offset = DEFAULT_SYMBOLS_OFFSET + base.currentOffset(); + this.base = base; + this.symbols = new ArrayList<>(); + } + + public Option getSymbol(int i) { + if (i >= this.offset) { + if (i - this.offset < this.symbols.size()) { + return Option.some(this.symbols.get(i - this.offset)); + } else { + return Option.none(); + } + } else { + return this.base.getSymbol(i); } + } - public Option get_s(int i) { - if (i >= this.offset) { - if (i - this.offset < this.symbols.size()) { - return Option.some(this.symbols.get(i - this.offset)); - } else { - return Option.none(); - } - } else { - return this.base.get_s(i); - } + public long insert(final String symbol) { + Option opt = this.base.get(symbol); + if (opt.isDefined()) { + return opt.get(); } - public long insert(final String symbol) { - Option opt = this.base.get(symbol); - if (opt.isDefined()) { - return opt.get(); - } - - int index = this.symbols.indexOf(symbol); - if (index != -1) { - return (long) (this.offset + index); - } - this.symbols.add(symbol); - return this.symbols.size() - 1 + this.offset; + int index = this.symbols.indexOf(symbol); + if (index != -1) { + return (long) (this.offset + index); } + this.symbols.add(symbol); + return this.symbols.size() - 1 + this.offset; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/Term.java b/src/main/java/org/biscuitsec/biscuit/datalog/Term.java index d90fe6ef..375359f7 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/Term.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/Term.java @@ -1,477 +1,504 @@ package org.biscuitsec.biscuit.datalog; -import biscuit.format.schema.Schema; -import org.biscuitsec.biscuit.error.Error; -import com.google.protobuf.ByteString; -import io.vavr.control.Either; - import static io.vavr.API.Left; import static io.vavr.API.Right; +import biscuit.format.schema.Schema; +import com.google.protobuf.ByteString; +import io.vavr.control.Either; import java.io.Serializable; import java.util.Arrays; import java.util.HashSet; - +import org.biscuitsec.biscuit.error.Error; public abstract class Term implements Serializable { - public abstract boolean match(final Term other); - public abstract Schema.TermV2 serialize(); - - static public Either deserialize_enumV2(Schema.TermV2 term) { - if(term.hasDate()) { - return Date.deserializeV2(term); - } else if(term.hasInteger()) { - return Integer.deserializeV2(term); - } else if(term.hasString()) { - return Str.deserializeV2(term); - } else if(term.hasBytes()) { - return Bytes.deserializeV2(term); - } else if(term.hasVariable()) { - return Variable.deserializeV2(term); - } else if(term.hasBool()) { - return Bool.deserializeV2(term); - } else if(term.hasSet()) { - return Set.deserializeV2(term); + public abstract boolean match(Term other); + + public abstract Schema.TermV2 serialize(); + + public static Either deserializeEnumV2(Schema.TermV2 term) { + if (term.hasDate()) { + return Date.deserializeV2(term); + } else if (term.hasInteger()) { + return Integer.deserializeV2(term); + } else if (term.hasString()) { + return Str.deserializeV2(term); + } else if (term.hasBytes()) { + return Bytes.deserializeV2(term); + } else if (term.hasVariable()) { + return Variable.deserializeV2(term); + } else if (term.hasBool()) { + return Bool.deserializeV2(term); + } else if (term.hasSet()) { + return Set.deserializeV2(term); + } else { + return Left(new Error.FormatError.DeserializationError("invalid Term kind: term.getKind()")); + } + } + + public abstract org.biscuitsec.biscuit.token.builder.Term toTerm(SymbolTable symbolTable); + + public static final class Date extends Term implements Serializable { + private final long value; + + public long value() { + return this.value; + } + + public boolean match(final Term other) { + if (other instanceof Variable) { + return true; } else { - return Left(new Error.FormatError.DeserializationError("invalid Term kind: term.getKind()")); + return this.equals(other); } - } + } - public abstract org.biscuitsec.biscuit.token.builder.Term toTerm(SymbolTable symbols); + public Date(final long value) { + this.value = value; + } - public final static class Date extends Term implements Serializable { - private final long value; - - public long value() { - return this.value; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - public boolean match(final Term other) { - if (other instanceof Variable) { - return true; - } else { - return this.equals(other); - } + if (o == null || getClass() != o.getClass()) { + return false; } - public Date(final long value) { - this.value = value; - } + Date date = (Date) o; - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + return value == date.value; + } - Date date = (Date) o; + @Override + public int hashCode() { + return (int) (value ^ (value >>> 32)); + } - return value == date.value; - } + @Override + public String toString() { + return "@" + this.value; + } - @Override - public int hashCode() { - return (int) (value ^ (value >>> 32)); - } + public Schema.TermV2 serialize() { + return Schema.TermV2.newBuilder().setDate(this.value).build(); + } - @Override - public String toString() { - return "@" + this.value; + public static Either deserializeV2(Schema.TermV2 term) { + if (!term.hasDate()) { + return Left(new Error.FormatError.DeserializationError("invalid Term kind, expected date")); + } else { + return Right(new Date(term.getDate())); } + } - public Schema.TermV2 serialize() { - return Schema.TermV2.newBuilder() - .setDate(this.value).build(); - } + public org.biscuitsec.biscuit.token.builder.Term toTerm(SymbolTable symbolTable) { + return new org.biscuitsec.biscuit.token.builder.Term.Date(this.value); + } + } - static public Either deserializeV2(Schema.TermV2 term) { - if(!term.hasDate()) { - return Left(new Error.FormatError.DeserializationError("invalid Term kind, expected date")); - } else { - return Right(new Date(term.getDate())); - } - } + public static final class Integer extends Term implements Serializable { + private final long value; + + public long value() { + return this.value; + } - public org.biscuitsec.biscuit.token.builder.Term toTerm(SymbolTable symbols) { - return new org.biscuitsec.biscuit.token.builder.Term.Date(this.value); + public boolean match(final Term other) { + if (other instanceof Variable) { + return true; } - } + if (other instanceof Integer) { + return this.value == ((Integer) other).value; + } + return false; + } - public final static class Integer extends Term implements Serializable { - private final long value; + public Integer(final long value) { + this.value = value; + } - public long value() { - return this.value; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - public boolean match(final Term other) { - if (other instanceof Variable) { - return true; - } - if (other instanceof Integer) { - return this.value == ((Integer) other).value; - } - return false; + if (o == null || getClass() != o.getClass()) { + return false; } - public Integer(final long value) { - this.value = value; - } + Integer integer = (Integer) o; - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + return value == integer.value; + } - Integer integer = (Integer) o; + @Override + public int hashCode() { + return (int) (value ^ (value >>> 32)); + } - return value == integer.value; - } + @Override + public String toString() { + return "" + this.value; + } - @Override - public int hashCode() { - return (int) (value ^ (value >>> 32)); - } + public Schema.TermV2 serialize() { + return Schema.TermV2.newBuilder().setInteger(this.value).build(); + } - @Override - public String toString() { - return "" + this.value; + public static Either deserializeV2(Schema.TermV2 term) { + if (!term.hasInteger()) { + return Left( + new Error.FormatError.DeserializationError("invalid Term kind, expected integer")); + } else { + return Right(new Integer(term.getInteger())); } + } - public Schema.TermV2 serialize() { - return Schema.TermV2.newBuilder() - .setInteger(this.value).build(); - } + public org.biscuitsec.biscuit.token.builder.Term toTerm(SymbolTable symbolTable) { + return new org.biscuitsec.biscuit.token.builder.Term.Integer(this.value); + } + } - static public Either deserializeV2(Schema.TermV2 term) { - if(!term.hasInteger()) { - return Left(new Error.FormatError.DeserializationError("invalid Term kind, expected integer")); - } else { - return Right(new Integer(term.getInteger())); - } - } + public static final class Bytes extends Term implements Serializable { + private final byte[] value; - public org.biscuitsec.biscuit.token.builder.Term toTerm(SymbolTable symbols) { - return new org.biscuitsec.biscuit.token.builder.Term.Integer(this.value); + public byte[] value() { + return this.value; + } + + public boolean match(final Term other) { + if (other instanceof Variable) { + return true; + } + if (other instanceof Bytes) { + return this.value.equals(((Bytes) other).value); } - } + return false; + } - public final static class Bytes extends Term implements Serializable { - private final byte[] value; + public Bytes(final byte[] value) { + this.value = value; + } - public byte[] value() { - return this.value; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - public boolean match(final Term other) { - if (other instanceof Variable) { - return true; - } - if (other instanceof Bytes) { - return this.value.equals(((Bytes) other).value); - } - return false; + if (o == null || getClass() != o.getClass()) { + return false; } - public Bytes(final byte[] value) { - this.value = value; - } + Bytes bytes = (Bytes) o; - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + return Arrays.equals(value, bytes.value); + } - Bytes bytes = (Bytes) o; + @Override + public int hashCode() { + return Arrays.hashCode(value); + } - return Arrays.equals(value, bytes.value); - } + @Override + public String toString() { + return this.value.toString(); + } - @Override - public int hashCode() { - return Arrays.hashCode(value); - } + public Schema.TermV2 serialize() { + return Schema.TermV2.newBuilder().setBytes(ByteString.copyFrom(this.value)).build(); + } - @Override - public String toString() { - return this.value.toString(); + public static Either deserializeV2(Schema.TermV2 term) { + if (!term.hasBytes()) { + return Left( + new Error.FormatError.DeserializationError("invalid Term kind, expected byte array")); + } else { + return Right(new Bytes(term.getBytes().toByteArray())); } + } - public Schema.TermV2 serialize() { - return Schema.TermV2.newBuilder() - .setBytes(ByteString.copyFrom(this.value)).build(); - } + public org.biscuitsec.biscuit.token.builder.Term toTerm(SymbolTable symbolTable) { + return new org.biscuitsec.biscuit.token.builder.Term.Bytes(this.value); + } + } - static public Either deserializeV2(Schema.TermV2 term) { - if(!term.hasBytes()) { - return Left(new Error.FormatError.DeserializationError("invalid Term kind, expected byte array")); - } else { - return Right(new Bytes(term.getBytes().toByteArray())); - } - } + public static final class Str extends Term implements Serializable { + private final long value; - public org.biscuitsec.biscuit.token.builder.Term toTerm(SymbolTable symbols) { - return new org.biscuitsec.biscuit.token.builder.Term.Bytes(this.value); + public long value() { + return this.value; + } + + public boolean match(final Term other) { + if (other instanceof Variable) { + return true; + } + if (other instanceof Str) { + return this.value == ((Str) other).value; } - } + return false; + } - public final static class Str extends Term implements Serializable { - private final long value; + public Str(final long value) { + this.value = value; + } - public long value() { - return this.value; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - public boolean match(final Term other) { - if (other instanceof Variable) { - return true; - } - if (other instanceof Str) { - return this.value == ((Str) other).value; - } - return false; + if (o == null || getClass() != o.getClass()) { + return false; } - public Str(final long value) { - this.value = value; - } + Str s = (Str) o; - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + return value == s.value; + } - Str s = (Str) o; + @Override + public int hashCode() { + return (int) (value ^ (value >>> 32)); + } - return value == s.value; - } + public Schema.TermV2 serialize() { + return Schema.TermV2.newBuilder().setString(this.value).build(); + } - @Override - public int hashCode() { - return (int) (value ^ (value >>> 32)); + public static Either deserializeV2(Schema.TermV2 term) { + if (!term.hasString()) { + return Left( + new Error.FormatError.DeserializationError("invalid Term kind, expected string")); + } else { + return Right(new Str(term.getString())); } + } - public Schema.TermV2 serialize() { - return Schema.TermV2.newBuilder() - .setString(this.value).build(); - } + public org.biscuitsec.biscuit.token.builder.Term toTerm(SymbolTable symbolTable) { + return new org.biscuitsec.biscuit.token.builder.Term.Str( + symbolTable.formatSymbol((int) this.value)); + } + } - static public Either deserializeV2(Schema.TermV2 term) { - if(!term.hasString()) { - return Left(new Error.FormatError.DeserializationError("invalid Term kind, expected string")); - } else { - return Right(new Str(term.getString())); - } - } + public static final class Variable extends Term implements Serializable { + private final long value; - public org.biscuitsec.biscuit.token.builder.Term toTerm(SymbolTable symbols) { - return new org.biscuitsec.biscuit.token.builder.Term.Str(symbols.print_symbol((int) this.value)); - } - } + public long value() { + return this.value; + } - public final static class Variable extends Term implements Serializable { - private final long value; + public boolean match(final Term other) { + return true; + } - public long value() { - return this.value; - } + public Variable(final long value) { + this.value = value; + } - public boolean match(final Term other) { - return true; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - public Variable(final long value) { - this.value = value; + if (o == null || getClass() != o.getClass()) { + return false; } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + Variable variable = (Variable) o; - Variable variable = (Variable) o; + return value == variable.value; + } - return value == variable.value; - } + @Override + public int hashCode() { + return (int) (value ^ (value >>> 32)); + } - @Override - public int hashCode() { - return (int) (value ^ (value >>> 32)); - } + @Override + public String toString() { + return this.value + "?"; + } - @Override - public String toString() { - return this.value + "?"; - } + public Schema.TermV2 serialize() { + return Schema.TermV2.newBuilder().setVariable((int) this.value).build(); + } - public Schema.TermV2 serialize() { - return Schema.TermV2.newBuilder() - .setVariable((int) this.value).build(); + public static Either deserializeV2(Schema.TermV2 term) { + if (!term.hasVariable()) { + return Left( + new Error.FormatError.DeserializationError("invalid Term kind, expected variable")); + } else { + return Right(new Variable(term.getVariable())); } + } - static public Either deserializeV2(Schema.TermV2 term) { - if(!term.hasVariable()) { - return Left(new Error.FormatError.DeserializationError("invalid Term kind, expected variable")); - } else { - return Right(new Variable(term.getVariable())); - } - } + public org.biscuitsec.biscuit.token.builder.Term toTerm(SymbolTable symbolTable) { + return new org.biscuitsec.biscuit.token.builder.Term.Variable( + symbolTable.formatSymbol((int) this.value)); + } + } - public org.biscuitsec.biscuit.token.builder.Term toTerm(SymbolTable symbols) { - return new org.biscuitsec.biscuit.token.builder.Term.Variable(symbols.print_symbol((int) this.value)); - } - } + public static final class Bool extends Term implements Serializable { + private final boolean value; - public final static class Bool extends Term implements Serializable { - private final boolean value; + public boolean value() { + return this.value; + } - public boolean value() { - return this.value; + public boolean match(final Term other) { + if (other instanceof Variable) { + return true; } - - public boolean match(final Term other) { - if (other instanceof Variable) { - return true; - } - if (other instanceof Bool) { - return this.value == ((Bool) other).value; - } - return false; + if (other instanceof Bool) { + return this.value == ((Bool) other).value; } + return false; + } + + public Bool(final boolean value) { + this.value = value; + } - public Bool(final boolean value) { - this.value = value; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + Bool bool = (Bool) o; - Bool bool = (Bool) o; + return value == bool.value; + } - return value == bool.value; - } + @Override + public int hashCode() { + return (value ? 1 : 0); + } - @Override - public int hashCode() { - return (value ? 1 : 0); - } + @Override + public String toString() { + return "" + this.value; + } - @Override - public String toString() { - return "" + this.value; - } + public Schema.TermV2 serialize() { + return Schema.TermV2.newBuilder().setBool(this.value).build(); + } - public Schema.TermV2 serialize() { - return Schema.TermV2.newBuilder() - .setBool(this.value).build(); + public static Either deserializeV2(Schema.TermV2 term) { + if (!term.hasBool()) { + return Left( + new Error.FormatError.DeserializationError("invalid Term kind, expected boolean")); + } else { + return Right(new Bool(term.getBool())); } + } - static public Either deserializeV2(Schema.TermV2 term) { - if(!term.hasBool()) { - return Left(new Error.FormatError.DeserializationError("invalid Term kind, expected boolean")); - } else { - return Right(new Bool(term.getBool())); - } - } + public org.biscuitsec.biscuit.token.builder.Term toTerm(SymbolTable symbolTable) { + return new org.biscuitsec.biscuit.token.builder.Term.Bool(this.value); + } + } - public org.biscuitsec.biscuit.token.builder.Term toTerm(SymbolTable symbols) { - return new org.biscuitsec.biscuit.token.builder.Term.Bool(this.value); - } - } + public static final class Set extends Term implements Serializable { + private final HashSet value; - public final static class Set extends Term implements Serializable { - private final HashSet value; + public HashSet value() { + return this.value; + } - public HashSet value() { - return this.value; + public boolean match(final Term other) { + if (other instanceof Variable) { + return true; } - - public boolean match(final Term other) { - if (other instanceof Variable) { - return true; - } - if (other instanceof Set) { - return this.value.equals(((Set) other).value); - } - return false; - } - - public Set(final HashSet value) { - this.value = value; + if (other instanceof Set) { + return this.value.equals(((Set) other).value); } + return false; + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - Set set = (Set) o; + public Set(final HashSet value) { + this.value = value; + } - return value.equals(set.value); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return value.hashCode(); + if (o == null || getClass() != o.getClass()) { + return false; } - @Override - public String toString() { - return "" + - value; - } + Set set = (Set) o; - public Schema.TermV2 serialize() { - Schema.TermSet.Builder s = Schema.TermSet.newBuilder(); + return value.equals(set.value); + } - for (Term l: this.value) { - s.addSet(l.serialize()); - } + @Override + public int hashCode() { + return value.hashCode(); + } - return Schema.TermV2.newBuilder() - .setSet(s).build(); - } + @Override + public String toString() { + return "" + value; + } - static public Either deserializeV2(Schema.TermV2 term) { - if(!term.hasSet()) { - return Left(new Error.FormatError.DeserializationError("invalid Term kind, expected set")); - } else { - java.util.HashSet values = new HashSet<>(); - Schema.TermSet s = term.getSet(); + public Schema.TermV2 serialize() { + Schema.TermSet.Builder s = Schema.TermSet.newBuilder(); - for (Schema.TermV2 l: s.getSetList()) { - Either res = Term.deserialize_enumV2(l); - if(res.isLeft()) { - Error.FormatError e = res.getLeft(); - return Left(e); - } else { - Term value = res.get(); + for (Term l : this.value) { + s.addSet(l.serialize()); + } - if(value instanceof Variable) { - return Left(new Error.FormatError.DeserializationError("sets cannot contain variables")); - } + return Schema.TermV2.newBuilder().setSet(s).build(); + } - values.add(value); - } + public static Either deserializeV2(Schema.TermV2 term) { + if (!term.hasSet()) { + return Left(new Error.FormatError.DeserializationError("invalid Term kind, expected set")); + } else { + java.util.HashSet values = new HashSet<>(); + Schema.TermSet s = term.getSet(); + + for (Schema.TermV2 l : s.getSetList()) { + Either res = Term.deserializeEnumV2(l); + if (res.isLeft()) { + Error.FormatError e = res.getLeft(); + return Left(e); + } else { + Term value = res.get(); + + if (value instanceof Variable) { + return Left( + new Error.FormatError.DeserializationError("sets cannot contain variables")); } - if(values.isEmpty()) { - return Left(new Error.FormatError.DeserializationError("invalid Set value")); - } else { - return Right(new Set(values)); - } - } - } + values.add(value); + } + } - public org.biscuitsec.biscuit.token.builder.Term toTerm(SymbolTable symbols) { - HashSet s = new HashSet<>(); + if (values.isEmpty()) { + return Left(new Error.FormatError.DeserializationError("invalid Set value")); + } else { + return Right(new Set(values)); + } + } + } - for(Term i: this.value) { - s.add(i.toTerm(symbols)); - } + public org.biscuitsec.biscuit.token.builder.Term toTerm(SymbolTable symbolTable) { + HashSet s = new HashSet<>(); - return new org.biscuitsec.biscuit.token.builder.Term.Set(s); + for (Term i : this.value) { + s.add(i.toTerm(symbolTable)); } - } + + return new org.biscuitsec.biscuit.token.builder.Term.Set(s); + } + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/TrustedOrigins.java b/src/main/java/org/biscuitsec/biscuit/datalog/TrustedOrigins.java index 5e063cb6..c42ca2d9 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/TrustedOrigins.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/TrustedOrigins.java @@ -3,85 +3,90 @@ import java.util.HashMap; import java.util.List; -public class TrustedOrigins { - private final Origin inner; +public final class TrustedOrigins { + private final Origin origin; - public TrustedOrigins(int... origins) { - Origin origin = new Origin(); - for (int i : origins) { - origin.add(i); - } - inner = origin; + public TrustedOrigins(int... origins) { + Origin origin = new Origin(); + for (int i : origins) { + origin.add(i); } + this.origin = origin; + } - private TrustedOrigins() { - inner = new Origin(); - } + private TrustedOrigins() { + origin = new Origin(); + } - private TrustedOrigins(Origin inner) { - if (inner == null) { - throw new RuntimeException(); - } - this.inner = inner; + private TrustedOrigins(Origin inner) { + if (inner == null) { + throw new RuntimeException(); } + this.origin = inner; + } - public TrustedOrigins clone() { - return new TrustedOrigins(this.inner.clone()); - } + public TrustedOrigins clone() { + return new TrustedOrigins(this.origin.clone()); + } - public static TrustedOrigins defaultOrigins() { - TrustedOrigins origins = new TrustedOrigins(); - origins.inner.add(0); - origins.inner.add(Long.MAX_VALUE); - return origins; - } + public static TrustedOrigins defaultOrigins() { + TrustedOrigins origins = new TrustedOrigins(); + origins.origin.add(0); + origins.origin.add(Long.MAX_VALUE); + return origins; + } - public static TrustedOrigins fromScopes(List ruleScopes, - TrustedOrigins defaultOrigins, - long currentBlock, - HashMap> publicKeyToBlockId) { - if (ruleScopes.isEmpty()) { - TrustedOrigins origins = defaultOrigins.clone(); - origins.inner.add(currentBlock); - origins.inner.add(Long.MAX_VALUE); - return origins; - } + public static TrustedOrigins fromScopes( + List ruleScopes, + TrustedOrigins defaultOrigins, + long currentBlock, + HashMap> publicKeyToBlockId) { + if (ruleScopes.isEmpty()) { + TrustedOrigins origins = defaultOrigins.clone(); + origins.origin.add(currentBlock); + origins.origin.add(Long.MAX_VALUE); + return origins; + } - TrustedOrigins origins = new TrustedOrigins(); - origins.inner.add(currentBlock); - origins.inner.add(Long.MAX_VALUE); + TrustedOrigins origins = new TrustedOrigins(); + origins.origin.add(currentBlock); + origins.origin.add(Long.MAX_VALUE); - for (Scope scope : ruleScopes) { - switch (scope.kind()) { - case Authority: - origins.inner.add(0); - break; - case Previous: - if (currentBlock != Long.MAX_VALUE) { - for (long i = 0; i < currentBlock + 1; i++) { - origins.inner.add(i); - } - } - break; - case PublicKey: - List blockIds = publicKeyToBlockId.get(scope.publicKey()); - if (blockIds != null) { - origins.inner.inner.addAll(blockIds); - } + for (Scope scope : ruleScopes) { + switch (scope.kind()) { + case Authority: + origins.origin.add(0); + break; + case Previous: + if (currentBlock != Long.MAX_VALUE) { + for (long i = 0; i < currentBlock + 1; i++) { + origins.origin.add(i); } - } - - return origins; + } + break; + case PublicKey: + List blockIds = publicKeyToBlockId.get(scope.getPublicKey()); + if (blockIds != null) { + origins.origin.addAll(blockIds); + } + break; + default: + } } - public boolean contains(Origin factOrigin) { - return this.inner.inner.containsAll(factOrigin.inner); - } + return origins; + } - @Override - public String toString() { - return "TrustedOrigins{" + - "inner=" + inner + - '}'; - } + public boolean contains(Origin factOrigin) { + return this.origin.containsAll(factOrigin); + } + + @Override + public String toString() { + return "TrustedOrigins{inner=" + origin + '}'; + } + + public Origin getOrigin() { + return origin; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/World.java b/src/main/java/org/biscuitsec/biscuit/datalog/World.java index 74aa2c70..2ebdbe9f 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/World.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/World.java @@ -1,158 +1,167 @@ package org.biscuitsec.biscuit.datalog; -import org.biscuitsec.biscuit.error.Error; import io.vavr.Tuple2; import io.vavr.control.Either; - import java.io.Serializable; import java.time.Instant; -import java.util.*; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; import java.util.function.Supplier; import java.util.stream.Stream; +import org.biscuitsec.biscuit.error.Error; -public class World implements Serializable { - private final FactSet facts; - private final RuleSet rules; - - public void add_fact(final Origin origin, final Fact fact) { - this.facts.add(origin, fact); - } - - - public void add_rule(Long origin, TrustedOrigins scope, Rule rule) { - this.rules.add(origin, scope, rule); - } +public final class World implements Serializable { + private final FactSet facts; + private final RuleSet rules; - public void clearRules() { - this.rules.clear(); - } + public void addFact(final Origin origin, final Fact fact) { + this.facts.add(origin, fact); + } - public void run(final SymbolTable symbols) throws Error { - this.run(new RunLimits(), symbols); - } + public void addRule(Long origin, TrustedOrigins scope, Rule rule) { + this.rules.add(origin, scope, rule); + } - public void run(RunLimits limits, final SymbolTable symbols) throws Error { - int iterations = 0; - Instant limit = Instant.now().plus(limits.maxTime); + public void clearRules() { + this.rules.clear(); + } - while(true) { - final FactSet newFacts = new FactSet(); + public void run(final SymbolTable symbolTable) throws Error { + this.run(new RunLimits(), symbolTable); + } - for(Map.Entry>> entry: this.rules.rules.entrySet()) { - for(Tuple2 t: entry.getValue()) { - Supplier>> factsSupplier = () -> this.facts.stream(entry.getKey()); + public void run(RunLimits limits, final SymbolTable symbolTable) throws Error { + int iterations = 0; + Instant limit = Instant.now().plus(limits.getMaxTime()); - Stream>> stream = t._2.apply(factsSupplier, t._1, symbols); - for (Iterator>> it = stream.iterator(); it.hasNext(); ) { - Either> res = it.next(); - if(Instant.now().compareTo(limit) >= 0) { - throw new Error.Timeout(); - } + while (true) { + final FactSet newFacts = new FactSet(); - if(res.isRight()) { - Tuple2 t2 = res.get(); - newFacts.add(t2._1, t2._2); - } else { - throw res.getLeft(); - } - } + for (Map.Entry>> entry : + this.rules.getRules().entrySet()) { + for (Tuple2 t : entry.getValue()) { + Supplier>> factsSupplier = + () -> this.facts.stream(entry.getKey()); + + Stream>> stream = + t._2.apply(factsSupplier, t._1, symbolTable); + for (Iterator>> it = stream.iterator(); + it.hasNext(); ) { + Either> res = it.next(); + if (Instant.now().compareTo(limit) >= 0) { + throw new Error.Timeout(); } - } - - final int len = this.facts.size(); - this.facts.merge(newFacts); - - if (this.facts.size() == len) { - return ; - } - if (this.facts.size() >= limits.maxFacts) { - throw new Error.TooManyFacts(); - } - - iterations += 1; - if(iterations >= limits.maxIterations) { - throw new Error.TooManyIterations(); - } + if (res.isRight()) { + Tuple2 t2 = res.get(); + newFacts.add(t2._1, t2._2); + } else { + throw res.getLeft(); + } + } + } } - } - public final FactSet facts() { - return this.facts; - } + final int len = this.facts.size(); + this.facts.merge(newFacts); - public RuleSet rules() { return this.rules; } - - public final FactSet query_rule(final Rule rule, Long origin, TrustedOrigins scope, SymbolTable symbols) throws Error { - final FactSet newFacts = new FactSet(); - - Supplier>> factsSupplier = () -> this.facts.stream(scope); - - Stream>> stream = rule.apply(factsSupplier, origin, symbols); - for (Iterator>> it = stream.iterator(); it.hasNext(); ) { - Either> res = it.next(); - - if (res.isRight()) { - Tuple2 t2 = res.get(); - newFacts.add(t2._1, t2._2); - } else { - throw res.getLeft(); - } + if (this.facts.size() == len) { + return; } - return newFacts; - } - - public final boolean query_match(final Rule rule, Long origin, TrustedOrigins scope, SymbolTable symbols) throws Error { - return rule.find_match(this.facts, origin, scope, symbols); - } - - public final boolean query_match_all(final Rule rule, TrustedOrigins scope, SymbolTable symbols) throws Error { - return rule.check_match_all(this.facts, scope, symbols); - } + if (this.facts.size() >= limits.getMaxFacts()) { + throw new Error.TooManyFacts(); + } + iterations += 1; + if (iterations >= limits.getMaxIterations()) { + throw new Error.TooManyIterations(); + } + } + } - public World() { - this.facts = new FactSet(); - this.rules = new RuleSet(); - } + public FactSet getFacts() { + return this.facts; + } - public World(FactSet facts) { - this.facts = facts.clone(); - this.rules = new RuleSet(); - } + public RuleSet getRules() { + return this.rules; + } - public World(FactSet facts, RuleSet rules) { - this.facts = facts.clone(); - this.rules = rules.clone(); - } + public FactSet queryRule(final Rule rule, Long origin, TrustedOrigins scope, SymbolTable symbolTable) + throws Error { + final FactSet newFacts = new FactSet(); - public World(World w) { - this.facts = w.facts.clone(); - this.rules = w.rules.clone(); - } + Supplier>> factsSupplier = () -> this.facts.stream(scope); - public String print(SymbolTable symbol_table) { - StringBuilder s = new StringBuilder(); + Stream>> stream = rule.apply(factsSupplier, origin, symbolTable); + for (Iterator>> it = stream.iterator(); it.hasNext(); ) { + Either> res = it.next(); - s.append("World {\n\t\tfacts: ["); - for(Map.Entry> entry: this.facts.facts().entrySet()) { - s.append("\n\t\t\t"+entry.getKey()+":"); - for(Fact f: entry.getValue()) { - s.append("\n\t\t\t\t"); - s.append(symbol_table.print_fact(f)); - } + if (res.isRight()) { + Tuple2 t2 = res.get(); + newFacts.add(t2._1, t2._2); + } else { + throw res.getLeft(); + } + } + + return newFacts; + } + + public boolean queryMatch(final Rule rule, Long origin, TrustedOrigins scope, SymbolTable symbolTable) + throws Error { + return rule.findMatch(this.facts, origin, scope, symbolTable); + } + + public boolean queryMatchAll(final Rule rule, TrustedOrigins scope, SymbolTable symbolTable) + throws Error { + return rule.checkMatchAll(this.facts, scope, symbolTable); + } + + public World() { + this.facts = new FactSet(); + this.rules = new RuleSet(); + } + + public World(FactSet facts) { + this.facts = facts.clone(); + this.rules = new RuleSet(); + } + + public World(FactSet facts, RuleSet rules) { + this.facts = facts.clone(); + this.rules = rules.clone(); + } + + public World(World w) { + this.facts = w.facts.clone(); + this.rules = w.rules.clone(); + } + + public String print(SymbolTable symbolTable) { + StringBuilder s = new StringBuilder(); + + s.append("World {\n\t\tfacts: ["); + for (Map.Entry> entry : this.facts.facts().entrySet()) { + s.append("\n\t\t\t" + entry.getKey() + ":"); + for (Fact f : entry.getValue()) { + s.append("\n\t\t\t\t"); + s.append(symbolTable.formatFact(f)); } + } - s.append("\n\t\t]\n\t\trules: ["); - for (Iterator it = this.rules.stream().iterator(); it.hasNext(); ) { - Rule r = it.next(); - s.append("\n\t\t\t"); - s.append(symbol_table.print_rule(r)); - } + s.append("\n\t\t]\n\t\trules: ["); + for (Iterator it = this.rules.stream().iterator(); it.hasNext(); ) { + Rule r = it.next(); + s.append("\n\t\t\t"); + s.append(symbolTable.formatRule(r)); + } - s.append("\n\t\t]\n\t}"); + s.append("\n\t\t]\n\t}"); - return s.toString(); - } + return s.toString(); + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/expressions/Expression.java b/src/main/java/org/biscuitsec/biscuit/datalog/expressions/Expression.java index d691ac07..964f4937 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/expressions/Expression.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/expressions/Expression.java @@ -1,100 +1,106 @@ package org.biscuitsec.biscuit.datalog.expressions; +import static io.vavr.API.Left; +import static io.vavr.API.Right; + import biscuit.format.schema.Schema; +import io.vavr.control.Either; +import io.vavr.control.Option; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.Map; +import java.util.Objects; +import org.biscuitsec.biscuit.datalog.SymbolTable; import org.biscuitsec.biscuit.datalog.TemporarySymbolTable; import org.biscuitsec.biscuit.datalog.Term; -import org.biscuitsec.biscuit.datalog.SymbolTable; import org.biscuitsec.biscuit.error.Error; -import io.vavr.control.Either; -import io.vavr.control.Option; -import java.util.*; +public final class Expression { + private final ArrayList ops; -import static io.vavr.API.Left; -import static io.vavr.API.Right; + public Expression(ArrayList ops) { + this.ops = ops; + } -public class Expression { - private final ArrayList ops; + public ArrayList getOps() { + return ops; + } - public Expression(ArrayList ops) { - this.ops = ops; + // FIXME: should return a Result + public Term evaluate(Map variables, TemporarySymbolTable temporarySymbolTable) + throws Error.Execution { + Deque stack = new ArrayDeque(16); // Default value + for (Op op : ops) { + op.evaluate(stack, variables, temporarySymbolTable); } - - public ArrayList getOps() { - return ops; + if (stack.size() == 1) { + return stack.pop(); + } else { + throw new Error.Execution(this, "execution"); } + } - //FIXME: should return a Result - public Term evaluate(Map variables, TemporarySymbolTable symbols) throws Error.Execution { - Deque stack = new ArrayDeque(16); //Default value - for(Op op: ops){ - op.evaluate(stack,variables, symbols); - } - if(stack.size() == 1){ - return stack.pop(); - } else { - throw new Error.Execution(this, "execution"); - } + public Option print(SymbolTable symbolTable) { + Deque stack = new ArrayDeque<>(); + for (Op op : ops) { + op.print(stack, symbolTable); } - - public Option print(SymbolTable symbols) { - Deque stack = new ArrayDeque<>(); - for (Op op : ops){ - op.print(stack, symbols); - } - if(stack.size() == 1){ - return Option.some(stack.remove()); - } else { - return Option.none(); - } + if (stack.size() == 1) { + return Option.some(stack.remove()); + } else { + return Option.none(); } + } - public Schema.ExpressionV2 serialize() { - Schema.ExpressionV2.Builder b = Schema.ExpressionV2.newBuilder(); + public Schema.ExpressionV2 serialize() { + Schema.ExpressionV2.Builder b = Schema.ExpressionV2.newBuilder(); - for(Op op: this.ops) { - b.addOps(op.serialize()); - } - - return b.build(); + for (Op op : this.ops) { + b.addOps(op.serialize()); } - static public Either deserializeV2(Schema.ExpressionV2 e) { - ArrayList ops = new ArrayList<>(); + return b.build(); + } - for(Schema.Op op: e.getOpsList()) { - Either res = Op.deserializeV2(op); + public static Either deserializeV2(Schema.ExpressionV2 e) { + ArrayList ops = new ArrayList<>(); - if(res.isLeft()) { - Error.FormatError err = res.getLeft(); - return Left(err); - } else { - ops.add(res.get()); - } - } + for (Schema.Op op : e.getOpsList()) { + Either res = Op.deserializeV2(op); - return Right(new Expression(ops)); + if (res.isLeft()) { + Error.FormatError err = res.getLeft(); + return Left(err); + } else { + ops.add(res.get()); + } } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - Expression that = (Expression) o; + return Right(new Expression(ops)); + } - return Objects.equals(ops, that.ops); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return ops != null ? ops.hashCode() : 0; + if (o == null || getClass() != o.getClass()) { + return false; } - @Override - public String toString() { - return "Expression{" + - "ops=" + ops + - '}'; - } + Expression that = (Expression) o; + + return Objects.equals(ops, that.ops); + } + + @Override + public int hashCode() { + return ops != null ? ops.hashCode() : 0; + } + + @Override + public String toString() { + return "Expression{ops=" + ops + '}'; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/expressions/Op.java b/src/main/java/org/biscuitsec/biscuit/datalog/expressions/Op.java index 42232f29..8353381e 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/expressions/Op.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/expressions/Op.java @@ -1,780 +1,833 @@ package org.biscuitsec.biscuit.datalog.expressions; +import static io.vavr.API.Left; +import static io.vavr.API.Right; + import biscuit.format.schema.Schema; -import org.biscuitsec.biscuit.datalog.TemporarySymbolTable; -import org.biscuitsec.biscuit.datalog.Term; -import org.biscuitsec.biscuit.datalog.SymbolTable; -import org.biscuitsec.biscuit.error.Error; import com.google.re2j.Matcher; import com.google.re2j.Pattern; import io.vavr.control.Either; import io.vavr.control.Option; - import java.io.UnsupportedEncodingException; -import java.util.*; - -import static io.vavr.API.Left; -import static io.vavr.API.Right; +import java.util.Arrays; +import java.util.Deque; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import org.biscuitsec.biscuit.datalog.SymbolTable; +import org.biscuitsec.biscuit.datalog.TemporarySymbolTable; +import org.biscuitsec.biscuit.datalog.Term; +import org.biscuitsec.biscuit.error.Error; public abstract class Op { - public abstract void evaluate(Deque stack, Map variables, TemporarySymbolTable symbols) throws Error.Execution; - - public abstract String print(Deque stack, SymbolTable symbols); + public abstract void evaluate( + Deque stack, Map variables, TemporarySymbolTable temporarySymbolTable) + throws Error.Execution; + + public abstract String print(Deque stack, SymbolTable symbols); + + public abstract Schema.Op serialize(); + + public static Either deserializeV2(Schema.Op op) { + if (op.hasValue()) { + return Term.deserializeEnumV2(op.getValue()).map(v -> new Op.Value(v)); + } else if (op.hasUnary()) { + return Op.Unary.deserializeV2(op.getUnary()); + } else if (op.hasBinary()) { + return Op.Binary.deserializeV1(op.getBinary()); + } else { + return Left(new Error.FormatError.DeserializationError("invalid unary operation")); + } + } - public abstract Schema.Op serialize(); + public static final class Value extends Op { + private final Term value; - static public Either deserializeV2(Schema.Op op) { - if (op.hasValue()) { - return Term.deserialize_enumV2(op.getValue()).map(v -> new Op.Value(v)); - } else if (op.hasUnary()) { - return Op.Unary.deserializeV2(op.getUnary()); - } else if (op.hasBinary()) { - return Op.Binary.deserializeV1(op.getBinary()); - } else { - return Left(new Error.FormatError.DeserializationError("invalid unary operation")); - } + public Value(Term value) { + this.value = value; } - public final static class Value extends Op { - private final Term value; + public Term getValue() { + return value; + } - public Value(Term value) { - this.value = value; + @Override + public void evaluate(Deque stack, Map variables, TemporarySymbolTable temporarySymbolTable) + throws Error.Execution { + if (value instanceof Term.Variable) { + Term.Variable var = (Term.Variable) value; + Term valueVar = variables.get(var.value()); + if (valueVar != null) { + stack.push(valueVar); + } else { + throw new Error.Execution("cannot find a variable for index " + value); } + } else { + stack.push(value); + } + } - public Term getValue() { - return value; - } + @Override + public String print(Deque stack, SymbolTable symbolTable) { + String s = symbolTable.formatTerm(value); + stack.push(s); + return s; + } - @Override - public void evaluate(Deque stack, Map variables, TemporarySymbolTable symbols) throws Error.Execution { - if (value instanceof Term.Variable) { - Term.Variable var = (Term.Variable) value; - Term valueVar = variables.get(var.value()); - if (valueVar != null) { - stack.push(valueVar); - } else { - throw new Error.Execution( "cannot find a variable for index "+value); - } - } else { - stack.push(value); - } + @Override + public Schema.Op serialize() { + Schema.Op.Builder b = Schema.Op.newBuilder(); - } + b.setValue(this.value.serialize()); - @Override - public String print(Deque stack, SymbolTable symbols) { - String s = symbols.print_term(value); - stack.push(s); - return s; - } + return b.build(); + } - @Override - public Schema.Op serialize() { - Schema.Op.Builder b = Schema.Op.newBuilder(); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } - b.setValue(this.value.serialize()); + Value value1 = (Value) o; - return b.build(); - } + return value.equals(value1.value); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + @Override + public int hashCode() { + return value.hashCode(); + } - Value value1 = (Value) o; + @Override + public String toString() { + return "Value(" + value + ')'; + } + } - return value.equals(value1.value); - } + public enum UnaryOp { + Negate, + Parens, + Length, + } - @Override - public int hashCode() { - return value.hashCode(); - } + public static final class Unary extends Op { + private final UnaryOp op; - @Override - public String toString() { - return "Value(" + value + ')'; - } + public Unary(UnaryOp op) { + this.op = op; } - public enum UnaryOp { - Negate, - Parens, - Length, + public UnaryOp getOp() { + return op; } - public final static class Unary extends Op { - private final UnaryOp op; - - public Unary(UnaryOp op) { - this.op = op; - } - - public UnaryOp getOp() { - return op; - } - - @Override - public void evaluate(Deque stack, Map variables, TemporarySymbolTable symbols) throws Error.Execution { - Term value = stack.pop(); - switch (this.op) { - case Negate: - if (value instanceof Term.Bool) { - Term.Bool b = (Term.Bool) value; - stack.push(new Term.Bool(!b.value())); - } else { - throw new Error.Execution("invalid type for negate op, expected boolean"); - } - break; - case Parens: - stack.push(value); - break; - case Length: - if (value instanceof Term.Str) { - Option s = symbols.get_s((int)((Term.Str) value).value()); - if(s.isEmpty()) { - throw new Error.Execution("string not found in symbols for id"+value); - } else { - try { - stack.push(new Term.Integer(s.get().getBytes("UTF-8").length)); - } catch (UnsupportedEncodingException e) { - throw new Error.Execution("cannot calculate string length: "+e.toString()); - } - } - } else if (value instanceof Term.Bytes) { - stack.push(new Term.Integer(((Term.Bytes) value).value().length)); - } else if (value instanceof Term.Set) { - stack.push(new Term.Integer(((Term.Set) value).value().size())); - } else { - throw new Error.Execution("invalid type for length op"); - } + @Override + public void evaluate(Deque stack, Map variables, TemporarySymbolTable temporarySymbolTable) + throws Error.Execution { + Term value = stack.pop(); + switch (this.op) { + case Negate: + if (value instanceof Term.Bool) { + Term.Bool b = (Term.Bool) value; + stack.push(new Term.Bool(!b.value())); + } else { + throw new Error.Execution("invalid type for negate op, expected boolean"); + } + break; + case Parens: + stack.push(value); + break; + case Length: + if (value instanceof Term.Str) { + Option s = temporarySymbolTable.getSymbol((int) ((Term.Str) value).value()); + if (s.isEmpty()) { + throw new Error.Execution("string not found in symbols for id" + value); + } else { + try { + stack.push(new Term.Integer(s.get().getBytes("UTF-8").length)); + } catch (UnsupportedEncodingException e) { + throw new Error.Execution("cannot calculate string length: " + e.toString()); + } } - } + } else if (value instanceof Term.Bytes) { + stack.push(new Term.Integer(((Term.Bytes) value).value().length)); + } else if (value instanceof Term.Set) { + stack.push(new Term.Integer(((Term.Set) value).value().size())); + } else { + throw new Error.Execution("invalid type for length op"); + } + break; + default: + throw new Error.Execution("invalid type for length op"); + } + } - @Override - public String print(Deque stack, SymbolTable symbols) { - String prec = stack.pop(); - String _s = ""; - switch (this.op) { - case Negate: - _s = "!" + prec; - stack.push(_s); - break; - case Parens: - _s = "(" + prec + ")"; - stack.push(_s); - break; - case Length: - _s = prec+".length()"; - stack.push(_s); - break; - } - return _s; - } + @Override + public String print(Deque stack, SymbolTable symbolTable) { + String prec = stack.pop(); + String s = ""; + switch (this.op) { + case Negate: + s = "!" + prec; + stack.push(s); + break; + case Parens: + s = "(" + prec + ")"; + stack.push(s); + break; + case Length: + s = prec + ".length()"; + stack.push(s); + break; + default: + } + return s; + } - @Override - public Schema.Op serialize() { - Schema.Op.Builder b = Schema.Op.newBuilder(); - - Schema.OpUnary.Builder b1 = Schema.OpUnary.newBuilder(); - - switch (this.op) { - case Negate: - b1.setKind(Schema.OpUnary.Kind.Negate); - break; - case Parens: - b1.setKind(Schema.OpUnary.Kind.Parens); - break; - case Length: - b1.setKind(Schema.OpUnary.Kind.Length); - break; - } + @Override + public Schema.Op serialize() { + Schema.Op.Builder b = Schema.Op.newBuilder(); - b.setUnary(b1.build()); + Schema.OpUnary.Builder b1 = Schema.OpUnary.newBuilder(); - return b.build(); - } + switch (this.op) { + case Negate: + b1.setKind(Schema.OpUnary.Kind.Negate); + break; + case Parens: + b1.setKind(Schema.OpUnary.Kind.Parens); + break; + case Length: + b1.setKind(Schema.OpUnary.Kind.Length); + break; + default: + } - static public Either deserializeV2(Schema.OpUnary op) { - switch (op.getKind()) { - case Negate: - return Right(new Op.Unary(UnaryOp.Negate)); - case Parens: - return Right(new Op.Unary(UnaryOp.Parens)); - case Length: - return Right(new Op.Unary(UnaryOp.Length)); - } + b.setUnary(b1.build()); - return Left(new Error.FormatError.DeserializationError("invalid unary operation")); - } + return b.build(); + } - @Override - public String toString() { - return "Unary."+op; - } + public static Either deserializeV2(Schema.OpUnary op) { + switch (op.getKind()) { + case Negate: + return Right(new Op.Unary(UnaryOp.Negate)); + case Parens: + return Right(new Op.Unary(UnaryOp.Parens)); + case Length: + return Right(new Op.Unary(UnaryOp.Length)); + default: + } + + return Left(new Error.FormatError.DeserializationError("invalid unary operation")); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + @Override + public String toString() { + return "Unary." + op; + } - Unary unary = (Unary) o; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } - return op == unary.op; - } + Unary unary = (Unary) o; - @Override - public int hashCode() { - return op.hashCode(); - } + return op == unary.op; } - public enum BinaryOp { - LessThan, - GreaterThan, - LessOrEqual, - GreaterOrEqual, - Equal, - NotEqual, - Contains, - Prefix, - Suffix, - Regex, - Add, - Sub, - Mul, - Div, - And, - Or, - Intersection, - Union, - BitwiseAnd, - BitwiseOr, - BitwiseXor, + @Override + public int hashCode() { + return op.hashCode(); + } + } + + public enum BinaryOp { + LessThan, + GreaterThan, + LessOrEqual, + GreaterOrEqual, + Equal, + NotEqual, + Contains, + Prefix, + Suffix, + Regex, + Add, + Sub, + Mul, + Div, + And, + Or, + Intersection, + Union, + BitwiseAnd, + BitwiseOr, + BitwiseXor, + } + + public static final class Binary extends Op { + private final BinaryOp op; + + public Binary(BinaryOp value) { + this.op = value; } - public final static class Binary extends Op { - private final BinaryOp op; - - public Binary(BinaryOp value) { - this.op = value; - } - - public BinaryOp getOp() { - return op; - } + public BinaryOp getOp() { + return op; + } - @Override - public void evaluate(Deque stack, Map variables, TemporarySymbolTable symbols) throws Error.Execution { - Term right = stack.pop(); - Term left = stack.pop(); - - switch (this.op) { - case LessThan: - if (right instanceof Term.Integer && left instanceof Term.Integer) { - stack.push(new Term.Bool(((Term.Integer) left).value() < ((Term.Integer) right).value())); - } - if (right instanceof Term.Date && left instanceof Term.Date) { - stack.push(new Term.Bool(((Term.Date) left).value() < ((Term.Date) right).value())); - } - break; - case GreaterThan: - if (right instanceof Term.Integer && left instanceof Term.Integer) { - stack.push(new Term.Bool(((Term.Integer) left).value() > ((Term.Integer) right).value())); - } - if (right instanceof Term.Date && left instanceof Term.Date) { - stack.push(new Term.Bool(((Term.Date) left).value() > ((Term.Date) right).value())); - } - break; - case LessOrEqual: - if (right instanceof Term.Integer && left instanceof Term.Integer) { - stack.push(new Term.Bool(((Term.Integer) left).value() <= ((Term.Integer) right).value())); - } - if (right instanceof Term.Date && left instanceof Term.Date) { - stack.push(new Term.Bool(((Term.Date) left).value() <= ((Term.Date) right).value())); - } - break; - case GreaterOrEqual: - if (right instanceof Term.Integer && left instanceof Term.Integer) { - stack.push(new Term.Bool(((Term.Integer) left).value() >= ((Term.Integer) right).value())); - } - if (right instanceof Term.Date && left instanceof Term.Date) { - stack.push(new Term.Bool(((Term.Date) left).value() >= ((Term.Date) right).value())); - } - break; - case Equal: - if (right instanceof Term.Bool && left instanceof Term.Bool) { - stack.push(new Term.Bool(((Term.Bool) left).value() == ((Term.Bool) right).value())); - } - if (right instanceof Term.Integer && left instanceof Term.Integer) { - stack.push(new Term.Bool(((Term.Integer) left).value() == ((Term.Integer) right).value())); - } - if (right instanceof Term.Str && left instanceof Term.Str) { - stack.push(new Term.Bool(((Term.Str) left).value() == ((Term.Str) right).value())); - } - if (right instanceof Term.Bytes && left instanceof Term.Bytes) { - stack.push(new Term.Bool(Arrays.equals(((Term.Bytes) left).value(), (((Term.Bytes) right).value())))); - } - if (right instanceof Term.Date && left instanceof Term.Date) { - stack.push(new Term.Bool(((Term.Date) left).value() == ((Term.Date) right).value())); - } - if (right instanceof Term.Set && left instanceof Term.Set) { - Set leftSet = ((Term.Set) left).value(); - Set rightSet = ((Term.Set) right).value(); - stack.push(new Term.Bool( leftSet.size() == rightSet.size() && leftSet.containsAll(rightSet))); - } - break; - case NotEqual: - if (right instanceof Term.Bool && left instanceof Term.Bool) { - stack.push(new Term.Bool(((Term.Bool) left).value() == ((Term.Bool) right).value())); - } - if (right instanceof Term.Integer && left instanceof Term.Integer) { - stack.push(new Term.Bool(((Term.Integer) left).value() != ((Term.Integer) right).value())); - } - if (right instanceof Term.Str && left instanceof Term.Str) { - stack.push(new Term.Bool(((Term.Str) left).value() != ((Term.Str) right).value())); - } - if (right instanceof Term.Bytes && left instanceof Term.Bytes) { - stack.push(new Term.Bool(!Arrays.equals(((Term.Bytes) left).value(), (((Term.Bytes) right).value())))); - } - if (right instanceof Term.Date && left instanceof Term.Date) { - stack.push(new Term.Bool(((Term.Date) left).value() != ((Term.Date) right).value())); - } - if (right instanceof Term.Set && left instanceof Term.Set) { - Set leftSet = ((Term.Set) left).value(); - Set rightSet = ((Term.Set) right).value(); - stack.push(new Term.Bool( leftSet.size() != rightSet.size() || !leftSet.containsAll(rightSet))); - } - break; - case Contains: - if (left instanceof Term.Set && - (right instanceof Term.Integer || - right instanceof Term.Str || - right instanceof Term.Bytes || - right instanceof Term.Date || - right instanceof Term.Bool)) { - - stack.push(new Term.Bool(((Term.Set) left).value().contains(right))); - } - if (right instanceof Term.Set && left instanceof Term.Set) { - Set leftSet = ((Term.Set) left).value(); - Set rightSet = ((Term.Set) right).value(); - stack.push(new Term.Bool(leftSet.containsAll(rightSet))); - } - if (left instanceof Term.Str && right instanceof Term.Str) { - Option left_s = symbols.get_s((int)((Term.Str) left).value()); - Option right_s = symbols.get_s((int)((Term.Str) right).value()); - - if(left_s.isEmpty()) { - throw new Error.Execution("cannot find string in symbols for index "+((Term.Str) left).value()); - } - if(right_s.isEmpty()) { - throw new Error.Execution("cannot find string in symbols for index "+((Term.Str) right).value()); - } - - - stack.push(new Term.Bool(left_s.get().contains(right_s.get()))); - } - break; - case Prefix: - if (right instanceof Term.Str && left instanceof Term.Str) { - Option left_s = symbols.get_s((int)((Term.Str) left).value()); - Option right_s = symbols.get_s((int)((Term.Str) right).value()); - if(left_s.isEmpty()) { - throw new Error.Execution("cannot find string in symbols for index "+((Term.Str) left).value()); - } - if(right_s.isEmpty()) { - throw new Error.Execution("cannot find string in symbols for index "+((Term.Str) right).value()); - } - - stack.push(new Term.Bool(left_s.get().startsWith(right_s.get()))); - } - break; - case Suffix: - if (right instanceof Term.Str && left instanceof Term.Str) { - Option left_s = symbols.get_s((int)((Term.Str) left).value()); - Option right_s = symbols.get_s((int)((Term.Str) right).value()); - if(left_s.isEmpty()) { - throw new Error.Execution("cannot find string in symbols for index "+((Term.Str) left).value()); - } - if(right_s.isEmpty()) { - throw new Error.Execution("cannot find string in symbols for index "+((Term.Str) right).value()); - } - stack.push(new Term.Bool(left_s.get().endsWith(right_s.get()))); - } - break; - case Regex: - if (right instanceof Term.Str && left instanceof Term.Str) { - Option left_s = symbols.get_s((int)((Term.Str) left).value()); - Option right_s = symbols.get_s((int)((Term.Str) right).value()); - if(left_s.isEmpty()) { - throw new Error.Execution("cannot find string in symbols for index "+((Term.Str) left).value()); - } - if(right_s.isEmpty()) { - throw new Error.Execution("cannot find string in symbols for index "+((Term.Str) right).value()); - } - - Pattern p = Pattern.compile(right_s.get()); - Matcher m = p.matcher(left_s.get()); - stack.push(new Term.Bool(m.find())); - } - break; - case Add: - if (right instanceof Term.Integer && left instanceof Term.Integer) { - try { - stack.push(new Term.Integer( - Math.addExact(((Term.Integer) left).value(), ((Term.Integer) right).value()) - )); - } catch (ArithmeticException e) { - throw new Error.Execution(Error.Execution.Kind.Overflow, "overflow"); - } - } - if (right instanceof Term.Str && left instanceof Term.Str) { - Option left_s = symbols.get_s((int)((Term.Str) left).value()); - Option right_s = symbols.get_s((int)((Term.Str) right).value()); - - if(left_s.isEmpty()) { - throw new Error.Execution("cannot find string in symbols for index "+((Term.Str) left).value()); - } - if(right_s.isEmpty()) { - throw new Error.Execution("cannot find string in symbols for index "+((Term.Str) right).value()); - } - - String concatenation = left_s.get() + right_s.get(); - long index = symbols.insert(concatenation); - stack.push(new Term.Str(index)); - } - break; - case Sub: - if (right instanceof Term.Integer && left instanceof Term.Integer) { - try { - stack.push(new Term.Integer( - Math.subtractExact(((Term.Integer) left).value(), ((Term.Integer) right).value()) - )); - } catch (ArithmeticException e) { - throw new Error.Execution(Error.Execution.Kind.Overflow, "overflow"); - } - } - break; - case Mul: - if (right instanceof Term.Integer && left instanceof Term.Integer) { - try { - stack.push(new Term.Integer( - Math.multiplyExact(((Term.Integer) left).value(), ((Term.Integer) right).value()) - )); - } catch (ArithmeticException e) { - throw new Error.Execution(Error.Execution.Kind.Overflow, "overflow"); - } - } - break; - case Div: - if (right instanceof Term.Integer && left instanceof Term.Integer) { - long rl = ((Term.Integer) right).value(); - if (rl != 0) { - stack.push(new Term.Integer(((Term.Integer) left).value() / rl)); - } - } - break; - case And: - if (right instanceof Term.Bool && left instanceof Term.Bool) { - stack.push(new Term.Bool(((Term.Bool) left).value() && ((Term.Bool) right).value())); - } - break; - case Or: - if (right instanceof Term.Bool && left instanceof Term.Bool) { - stack.push(new Term.Bool(((Term.Bool) left).value() || ((Term.Bool) right).value())); - } - break; - case Intersection: - if (right instanceof Term.Set && left instanceof Term.Set) { - HashSet intersec = new HashSet(); - HashSet _right = ((Term.Set) right).value(); - HashSet _left = ((Term.Set) left).value(); - for (Term _id : _right) { - if (_left.contains(_id)) { - intersec.add(_id); - } - } - stack.push(new Term.Set(intersec)); - } - break; - case Union: - if (right instanceof Term.Set && left instanceof Term.Set) { - HashSet union = new HashSet(); - HashSet _right = ((Term.Set) right).value(); - HashSet _left = ((Term.Set) left).value(); - union.addAll(_right); - union.addAll(_left); - stack.push(new Term.Set(union)); - } - break; - case BitwiseAnd: - if (right instanceof Term.Integer && left instanceof Term.Integer) { - long r = ((Term.Integer) right).value(); - long l = ((Term.Integer) left).value(); - stack.push(new Term.Integer(r & l)); - } - break; - case BitwiseOr: - if (right instanceof Term.Integer && left instanceof Term.Integer) { - long r = ((Term.Integer) right).value(); - long l = ((Term.Integer) left).value(); - stack.push(new Term.Integer(r | l)); - } - break; - case BitwiseXor: - if (right instanceof Term.Integer && left instanceof Term.Integer) { - long r = ((Term.Integer) right).value(); - long l = ((Term.Integer) left).value(); - stack.push(new Term.Integer(r ^ l)); - } - break; - default: - throw new Error.Execution("binary exec error for op"+this); + @Override + public void evaluate(Deque stack, Map variables, TemporarySymbolTable temporarySymbolTable) + throws Error.Execution { + Term right = stack.pop(); + Term left = stack.pop(); + + switch (this.op) { + case LessThan: + if (right instanceof Term.Integer && left instanceof Term.Integer) { + stack.push( + new Term.Bool(((Term.Integer) left).value() < ((Term.Integer) right).value())); + } + if (right instanceof Term.Date && left instanceof Term.Date) { + stack.push(new Term.Bool(((Term.Date) left).value() < ((Term.Date) right).value())); + } + break; + case GreaterThan: + if (right instanceof Term.Integer && left instanceof Term.Integer) { + stack.push( + new Term.Bool(((Term.Integer) left).value() > ((Term.Integer) right).value())); + } + if (right instanceof Term.Date && left instanceof Term.Date) { + stack.push(new Term.Bool(((Term.Date) left).value() > ((Term.Date) right).value())); + } + break; + case LessOrEqual: + if (right instanceof Term.Integer && left instanceof Term.Integer) { + stack.push( + new Term.Bool(((Term.Integer) left).value() <= ((Term.Integer) right).value())); + } + if (right instanceof Term.Date && left instanceof Term.Date) { + stack.push(new Term.Bool(((Term.Date) left).value() <= ((Term.Date) right).value())); + } + break; + case GreaterOrEqual: + if (right instanceof Term.Integer && left instanceof Term.Integer) { + stack.push( + new Term.Bool(((Term.Integer) left).value() >= ((Term.Integer) right).value())); + } + if (right instanceof Term.Date && left instanceof Term.Date) { + stack.push(new Term.Bool(((Term.Date) left).value() >= ((Term.Date) right).value())); + } + break; + case Equal: + if (right instanceof Term.Bool && left instanceof Term.Bool) { + stack.push(new Term.Bool(((Term.Bool) left).value() == ((Term.Bool) right).value())); + } + if (right instanceof Term.Integer && left instanceof Term.Integer) { + stack.push( + new Term.Bool(((Term.Integer) left).value() == ((Term.Integer) right).value())); + } + if (right instanceof Term.Str && left instanceof Term.Str) { + stack.push(new Term.Bool(((Term.Str) left).value() == ((Term.Str) right).value())); + } + if (right instanceof Term.Bytes && left instanceof Term.Bytes) { + stack.push( + new Term.Bool( + Arrays.equals(((Term.Bytes) left).value(), (((Term.Bytes) right).value())))); + } + if (right instanceof Term.Date && left instanceof Term.Date) { + stack.push(new Term.Bool(((Term.Date) left).value() == ((Term.Date) right).value())); + } + if (right instanceof Term.Set && left instanceof Term.Set) { + Set leftSet = ((Term.Set) left).value(); + Set rightSet = ((Term.Set) right).value(); + stack.push( + new Term.Bool(leftSet.size() == rightSet.size() && leftSet.containsAll(rightSet))); + } + break; + case NotEqual: + if (right instanceof Term.Bool && left instanceof Term.Bool) { + stack.push(new Term.Bool(((Term.Bool) left).value() == ((Term.Bool) right).value())); + } + if (right instanceof Term.Integer && left instanceof Term.Integer) { + stack.push( + new Term.Bool(((Term.Integer) left).value() != ((Term.Integer) right).value())); + } + if (right instanceof Term.Str && left instanceof Term.Str) { + stack.push(new Term.Bool(((Term.Str) left).value() != ((Term.Str) right).value())); + } + if (right instanceof Term.Bytes && left instanceof Term.Bytes) { + stack.push( + new Term.Bool( + !Arrays.equals(((Term.Bytes) left).value(), (((Term.Bytes) right).value())))); + } + if (right instanceof Term.Date && left instanceof Term.Date) { + stack.push(new Term.Bool(((Term.Date) left).value() != ((Term.Date) right).value())); + } + if (right instanceof Term.Set && left instanceof Term.Set) { + Set leftSet = ((Term.Set) left).value(); + Set rightSet = ((Term.Set) right).value(); + stack.push( + new Term.Bool(leftSet.size() != rightSet.size() || !leftSet.containsAll(rightSet))); + } + break; + case Contains: + if (left instanceof Term.Set + && (right instanceof Term.Integer + || right instanceof Term.Str + || right instanceof Term.Bytes + || right instanceof Term.Date + || right instanceof Term.Bool)) { + + stack.push(new Term.Bool(((Term.Set) left).value().contains(right))); + } + if (right instanceof Term.Set && left instanceof Term.Set) { + Set leftSet = ((Term.Set) left).value(); + Set rightSet = ((Term.Set) right).value(); + stack.push(new Term.Bool(leftSet.containsAll(rightSet))); + } + if (left instanceof Term.Str && right instanceof Term.Str) { + Option leftS = temporarySymbolTable.getSymbol((int) ((Term.Str) left).value()); + Option rightS = temporarySymbolTable.getSymbol((int) ((Term.Str) right).value()); + + if (leftS.isEmpty()) { + throw new Error.Execution( + "cannot find string in symbols for index " + ((Term.Str) left).value()); + } + if (rightS.isEmpty()) { + throw new Error.Execution( + "cannot find string in symbols for index " + ((Term.Str) right).value()); } - } - @Override - public String print(Deque stack, SymbolTable symbols) { - String right = stack.pop(); - String left = stack.pop(); - String _s = ""; - switch (this.op) { - case LessThan: - _s = left + " < " + right; - stack.push(_s); - break; - case GreaterThan: - _s = left + " > " + right; - stack.push(_s); - break; - case LessOrEqual: - _s = left + " <= " + right; - stack.push(_s); - break; - case GreaterOrEqual: - _s = left + " >= " + right; - stack.push(_s); - break; - case Equal: - _s = left + " == " + right; - stack.push(_s); - break; - case NotEqual: - _s = left + " != " + right; - stack.push(_s); - break; - case Contains: - _s = left + ".contains(" + right + ")"; - stack.push(_s); - break; - case Prefix: - _s = left + ".starts_with(" + right + ")"; - stack.push(_s); - break; - case Suffix: - _s = left + ".ends_with(" + right + ")"; - stack.push(_s); - break; - case Regex: - _s = left + ".matches(" + right + ")"; - stack.push(_s); - break; - case Add: - _s = left + " + " + right; - stack.push(_s); - break; - case Sub: - _s = left + " - " + right; - stack.push(_s); - break; - case Mul: - _s = left + " * " + right; - stack.push(_s); - break; - case Div: - _s = left + " / " + right; - stack.push(_s); - break; - case And: - _s = left + " && " + right; - stack.push(_s); - break; - case Or: - _s = left + " || " + right; - stack.push(_s); - break; - case Intersection: - _s = left + ".intersection("+right+")"; - stack.push(_s); - break; - case Union: - _s = left + ".union("+right+")"; - stack.push(_s); - break; - case BitwiseAnd: - _s = left + " & " + right; - stack.push(_s); - break; - case BitwiseOr: - _s = left + " | " + right; - stack.push(_s); - break; - case BitwiseXor: - _s = left + " ^ " + right; - stack.push(_s); - break; + stack.push(new Term.Bool(leftS.get().contains(rightS.get()))); + } + break; + case Prefix: + if (right instanceof Term.Str && left instanceof Term.Str) { + Option leftS = temporarySymbolTable.getSymbol((int) ((Term.Str) left).value()); + Option rightS = temporarySymbolTable.getSymbol((int) ((Term.Str) right).value()); + if (leftS.isEmpty()) { + throw new Error.Execution( + "cannot find string in symbols for index " + ((Term.Str) left).value()); + } + if (rightS.isEmpty()) { + throw new Error.Execution( + "cannot find string in symbols for index " + ((Term.Str) right).value()); } - return _s; - } + stack.push(new Term.Bool(leftS.get().startsWith(rightS.get()))); + } + break; + case Suffix: + if (right instanceof Term.Str && left instanceof Term.Str) { + Option leftS = temporarySymbolTable.getSymbol((int) ((Term.Str) left).value()); + Option rightS = temporarySymbolTable.getSymbol((int) ((Term.Str) right).value()); + if (leftS.isEmpty()) { + throw new Error.Execution( + "cannot find string in symbols for index " + ((Term.Str) left).value()); + } + if (rightS.isEmpty()) { + throw new Error.Execution( + "cannot find string in symbols for index " + ((Term.Str) right).value()); + } + stack.push(new Term.Bool(leftS.get().endsWith(rightS.get()))); + } + break; + case Regex: + if (right instanceof Term.Str && left instanceof Term.Str) { + Option leftS = temporarySymbolTable.getSymbol((int) ((Term.Str) left).value()); + Option rightS = temporarySymbolTable.getSymbol((int) ((Term.Str) right).value()); + if (leftS.isEmpty()) { + throw new Error.Execution( + "cannot find string in symbols for index " + ((Term.Str) left).value()); + } + if (rightS.isEmpty()) { + throw new Error.Execution( + "cannot find string in symbols for index " + ((Term.Str) right).value()); + } - @Override - public Schema.Op serialize() { - Schema.Op.Builder b = Schema.Op.newBuilder(); - - Schema.OpBinary.Builder b1 = Schema.OpBinary.newBuilder(); - - switch (this.op) { - case LessThan: - b1.setKind(Schema.OpBinary.Kind.LessThan); - break; - case GreaterThan: - b1.setKind(Schema.OpBinary.Kind.GreaterThan); - break; - case LessOrEqual: - b1.setKind(Schema.OpBinary.Kind.LessOrEqual); - break; - case GreaterOrEqual: - b1.setKind(Schema.OpBinary.Kind.GreaterOrEqual); - break; - case Equal: - b1.setKind(Schema.OpBinary.Kind.Equal); - break; - case NotEqual: - b1.setKind(Schema.OpBinary.Kind.NotEqual); - break; - case Contains: - b1.setKind(Schema.OpBinary.Kind.Contains); - break; - case Prefix: - b1.setKind(Schema.OpBinary.Kind.Prefix); - break; - case Suffix: - b1.setKind(Schema.OpBinary.Kind.Suffix); - break; - case Regex: - b1.setKind(Schema.OpBinary.Kind.Regex); - break; - case Add: - b1.setKind(Schema.OpBinary.Kind.Add); - break; - case Sub: - b1.setKind(Schema.OpBinary.Kind.Sub); - break; - case Mul: - b1.setKind(Schema.OpBinary.Kind.Mul); - break; - case Div: - b1.setKind(Schema.OpBinary.Kind.Div); - break; - case And: - b1.setKind(Schema.OpBinary.Kind.And); - break; - case Or: - b1.setKind(Schema.OpBinary.Kind.Or); - break; - case Intersection: - b1.setKind(Schema.OpBinary.Kind.Intersection); - break; - case Union: - b1.setKind(Schema.OpBinary.Kind.Union); - break; - case BitwiseAnd: - b1.setKind(Schema.OpBinary.Kind.BitwiseAnd); - break; - case BitwiseOr: - b1.setKind(Schema.OpBinary.Kind.BitwiseOr); - break; - case BitwiseXor: - b1.setKind(Schema.OpBinary.Kind.BitwiseXor); - break; + Pattern p = Pattern.compile(rightS.get()); + Matcher m = p.matcher(leftS.get()); + stack.push(new Term.Bool(m.find())); + } + break; + case Add: + if (right instanceof Term.Integer && left instanceof Term.Integer) { + try { + stack.push( + new Term.Integer( + Math.addExact( + ((Term.Integer) left).value(), ((Term.Integer) right).value()))); + } catch (ArithmeticException e) { + throw new Error.Execution(Error.Execution.Kind.Overflow, "overflow"); + } + } + if (right instanceof Term.Str && left instanceof Term.Str) { + Option leftS = temporarySymbolTable.getSymbol((int) ((Term.Str) left).value()); + Option rightS = temporarySymbolTable.getSymbol((int) ((Term.Str) right).value()); + + if (leftS.isEmpty()) { + throw new Error.Execution( + "cannot find string in symbols for index " + ((Term.Str) left).value()); + } + if (rightS.isEmpty()) { + throw new Error.Execution( + "cannot find string in symbols for index " + ((Term.Str) right).value()); } - b.setBinary(b1.build()); + String concatenation = leftS.get() + rightS.get(); + long index = temporarySymbolTable.insert(concatenation); + stack.push(new Term.Str(index)); + } + break; + case Sub: + if (right instanceof Term.Integer && left instanceof Term.Integer) { + try { + stack.push( + new Term.Integer( + Math.subtractExact( + ((Term.Integer) left).value(), ((Term.Integer) right).value()))); + } catch (ArithmeticException e) { + throw new Error.Execution(Error.Execution.Kind.Overflow, "overflow"); + } + } + break; + case Mul: + if (right instanceof Term.Integer && left instanceof Term.Integer) { + try { + stack.push( + new Term.Integer( + Math.multiplyExact( + ((Term.Integer) left).value(), ((Term.Integer) right).value()))); + } catch (ArithmeticException e) { + throw new Error.Execution(Error.Execution.Kind.Overflow, "overflow"); + } + } + break; + case Div: + if (right instanceof Term.Integer && left instanceof Term.Integer) { + long rl = ((Term.Integer) right).value(); + if (rl != 0) { + stack.push(new Term.Integer(((Term.Integer) left).value() / rl)); + } + } + break; + case And: + if (right instanceof Term.Bool && left instanceof Term.Bool) { + stack.push(new Term.Bool(((Term.Bool) left).value() && ((Term.Bool) right).value())); + } + break; + case Or: + if (right instanceof Term.Bool && left instanceof Term.Bool) { + stack.push(new Term.Bool(((Term.Bool) left).value() || ((Term.Bool) right).value())); + } + break; + case Intersection: + if (right instanceof Term.Set && left instanceof Term.Set) { + HashSet intersec = new HashSet(); + HashSet setRight = ((Term.Set) right).value(); + HashSet setLeft = ((Term.Set) left).value(); + for (Term locId : setRight) { + if (setLeft.contains(locId)) { + intersec.add(locId); + } + } + stack.push(new Term.Set(intersec)); + } + break; + case Union: + if (right instanceof Term.Set && left instanceof Term.Set) { + HashSet union = new HashSet(); + HashSet setRight = ((Term.Set) right).value(); + HashSet setLeft = ((Term.Set) left).value(); + union.addAll(setRight); + union.addAll(setLeft); + stack.push(new Term.Set(union)); + } + break; + case BitwiseAnd: + if (right instanceof Term.Integer && left instanceof Term.Integer) { + long r = ((Term.Integer) right).value(); + long l = ((Term.Integer) left).value(); + stack.push(new Term.Integer(r & l)); + } + break; + case BitwiseOr: + if (right instanceof Term.Integer && left instanceof Term.Integer) { + long r = ((Term.Integer) right).value(); + long l = ((Term.Integer) left).value(); + stack.push(new Term.Integer(r | l)); + } + break; + case BitwiseXor: + if (right instanceof Term.Integer && left instanceof Term.Integer) { + long r = ((Term.Integer) right).value(); + long l = ((Term.Integer) left).value(); + stack.push(new Term.Integer(r ^ l)); + } + break; + default: + throw new Error.Execution("binary exec error for op" + this); + } + } - return b.build(); - } + @Override + public String print(Deque stack, SymbolTable symbolTable) { + String right = stack.pop(); + String left = stack.pop(); + String s = ""; + switch (this.op) { + case LessThan: + s = left + " < " + right; + stack.push(s); + break; + case GreaterThan: + s = left + " > " + right; + stack.push(s); + break; + case LessOrEqual: + s = left + " <= " + right; + stack.push(s); + break; + case GreaterOrEqual: + s = left + " >= " + right; + stack.push(s); + break; + case Equal: + s = left + " == " + right; + stack.push(s); + break; + case NotEqual: + s = left + " != " + right; + stack.push(s); + break; + case Contains: + s = left + ".contains(" + right + ")"; + stack.push(s); + break; + case Prefix: + s = left + ".starts_with(" + right + ")"; + stack.push(s); + break; + case Suffix: + s = left + ".ends_with(" + right + ")"; + stack.push(s); + break; + case Regex: + s = left + ".matches(" + right + ")"; + stack.push(s); + break; + case Add: + s = left + " + " + right; + stack.push(s); + break; + case Sub: + s = left + " - " + right; + stack.push(s); + break; + case Mul: + s = left + " * " + right; + stack.push(s); + break; + case Div: + s = left + " / " + right; + stack.push(s); + break; + case And: + s = left + " && " + right; + stack.push(s); + break; + case Or: + s = left + " || " + right; + stack.push(s); + break; + case Intersection: + s = left + ".intersection(" + right + ")"; + stack.push(s); + break; + case Union: + s = left + ".union(" + right + ")"; + stack.push(s); + break; + case BitwiseAnd: + s = left + " & " + right; + stack.push(s); + break; + case BitwiseOr: + s = left + " | " + right; + stack.push(s); + break; + case BitwiseXor: + s = left + " ^ " + right; + stack.push(s); + break; + default: + } + + return s; + } - static public Either deserializeV1(Schema.OpBinary op) { - switch (op.getKind()) { - case LessThan: - return Right(new Op.Binary(BinaryOp.LessThan)); - case GreaterThan: - return Right(new Op.Binary(BinaryOp.GreaterThan)); - case LessOrEqual: - return Right(new Op.Binary(BinaryOp.LessOrEqual)); - case GreaterOrEqual: - return Right(new Op.Binary(BinaryOp.GreaterOrEqual)); - case Equal: - return Right(new Op.Binary(BinaryOp.Equal)); - case NotEqual: - return Right(new Op.Binary(BinaryOp.NotEqual)); - case Contains: - return Right(new Op.Binary(BinaryOp.Contains)); - case Prefix: - return Right(new Op.Binary(BinaryOp.Prefix)); - case Suffix: - return Right(new Op.Binary(BinaryOp.Suffix)); - case Regex: - return Right(new Op.Binary(BinaryOp.Regex)); - case Add: - return Right(new Op.Binary(BinaryOp.Add)); - case Sub: - return Right(new Op.Binary(BinaryOp.Sub)); - case Mul: - return Right(new Op.Binary(BinaryOp.Mul)); - case Div: - return Right(new Op.Binary(BinaryOp.Div)); - case And: - return Right(new Op.Binary(BinaryOp.And)); - case Or: - return Right(new Op.Binary(BinaryOp.Or)); - case Intersection: - return Right(new Op.Binary(BinaryOp.Intersection)); - case Union: - return Right(new Op.Binary(BinaryOp.Union)); - case BitwiseAnd: - return Right(new Op.Binary(BinaryOp.BitwiseAnd)); - case BitwiseOr: - return Right(new Op.Binary(BinaryOp.BitwiseOr)); - case BitwiseXor: - return Right(new Op.Binary(BinaryOp.BitwiseXor)); - } + @Override + public Schema.Op serialize() { + Schema.Op.Builder b = Schema.Op.newBuilder(); + + Schema.OpBinary.Builder b1 = Schema.OpBinary.newBuilder(); + + switch (this.op) { + case LessThan: + b1.setKind(Schema.OpBinary.Kind.LessThan); + break; + case GreaterThan: + b1.setKind(Schema.OpBinary.Kind.GreaterThan); + break; + case LessOrEqual: + b1.setKind(Schema.OpBinary.Kind.LessOrEqual); + break; + case GreaterOrEqual: + b1.setKind(Schema.OpBinary.Kind.GreaterOrEqual); + break; + case Equal: + b1.setKind(Schema.OpBinary.Kind.Equal); + break; + case NotEqual: + b1.setKind(Schema.OpBinary.Kind.NotEqual); + break; + case Contains: + b1.setKind(Schema.OpBinary.Kind.Contains); + break; + case Prefix: + b1.setKind(Schema.OpBinary.Kind.Prefix); + break; + case Suffix: + b1.setKind(Schema.OpBinary.Kind.Suffix); + break; + case Regex: + b1.setKind(Schema.OpBinary.Kind.Regex); + break; + case Add: + b1.setKind(Schema.OpBinary.Kind.Add); + break; + case Sub: + b1.setKind(Schema.OpBinary.Kind.Sub); + break; + case Mul: + b1.setKind(Schema.OpBinary.Kind.Mul); + break; + case Div: + b1.setKind(Schema.OpBinary.Kind.Div); + break; + case And: + b1.setKind(Schema.OpBinary.Kind.And); + break; + case Or: + b1.setKind(Schema.OpBinary.Kind.Or); + break; + case Intersection: + b1.setKind(Schema.OpBinary.Kind.Intersection); + break; + case Union: + b1.setKind(Schema.OpBinary.Kind.Union); + break; + case BitwiseAnd: + b1.setKind(Schema.OpBinary.Kind.BitwiseAnd); + break; + case BitwiseOr: + b1.setKind(Schema.OpBinary.Kind.BitwiseOr); + break; + case BitwiseXor: + b1.setKind(Schema.OpBinary.Kind.BitwiseXor); + break; + default: + } + + b.setBinary(b1.build()); + + return b.build(); + } - return Left(new Error.FormatError.DeserializationError("invalid binary operation: "+op.getKind())); - } + public static Either deserializeV1(Schema.OpBinary op) { + switch (op.getKind()) { + case LessThan: + return Right(new Op.Binary(BinaryOp.LessThan)); + case GreaterThan: + return Right(new Op.Binary(BinaryOp.GreaterThan)); + case LessOrEqual: + return Right(new Op.Binary(BinaryOp.LessOrEqual)); + case GreaterOrEqual: + return Right(new Op.Binary(BinaryOp.GreaterOrEqual)); + case Equal: + return Right(new Op.Binary(BinaryOp.Equal)); + case NotEqual: + return Right(new Op.Binary(BinaryOp.NotEqual)); + case Contains: + return Right(new Op.Binary(BinaryOp.Contains)); + case Prefix: + return Right(new Op.Binary(BinaryOp.Prefix)); + case Suffix: + return Right(new Op.Binary(BinaryOp.Suffix)); + case Regex: + return Right(new Op.Binary(BinaryOp.Regex)); + case Add: + return Right(new Op.Binary(BinaryOp.Add)); + case Sub: + return Right(new Op.Binary(BinaryOp.Sub)); + case Mul: + return Right(new Op.Binary(BinaryOp.Mul)); + case Div: + return Right(new Op.Binary(BinaryOp.Div)); + case And: + return Right(new Op.Binary(BinaryOp.And)); + case Or: + return Right(new Op.Binary(BinaryOp.Or)); + case Intersection: + return Right(new Op.Binary(BinaryOp.Intersection)); + case Union: + return Right(new Op.Binary(BinaryOp.Union)); + case BitwiseAnd: + return Right(new Op.Binary(BinaryOp.BitwiseAnd)); + case BitwiseOr: + return Right(new Op.Binary(BinaryOp.BitwiseOr)); + case BitwiseXor: + return Right(new Op.Binary(BinaryOp.BitwiseXor)); + default: + return Left( + new Error.FormatError.DeserializationError( + "invalid binary operation: " + op.getKind())); + } + } - @Override - public String toString() { - return "Binary."+ op; - } + @Override + public String toString() { + return "Binary." + op; + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } - Binary binary = (Binary) o; + Binary binary = (Binary) o; - return op == binary.op; - } + return op == binary.op; + } - @Override - public int hashCode() { - return op.hashCode(); - } + @Override + public int hashCode() { + return op.hashCode(); } + } } diff --git a/src/main/java/org/biscuitsec/biscuit/datalog/package-info.java b/src/main/java/org/biscuitsec/biscuit/datalog/package-info.java index f8a2fe55..4acedf5e 100644 --- a/src/main/java/org/biscuitsec/biscuit/datalog/package-info.java +++ b/src/main/java/org/biscuitsec/biscuit/datalog/package-info.java @@ -1,4 +1,2 @@ -/** - * Implementation of the Datalog engine for the check language - */ -package org.biscuitsec.biscuit.datalog; \ No newline at end of file +/** Implementation of the Datalog engine for the check language */ +package org.biscuitsec.biscuit.datalog; diff --git a/src/main/java/org/biscuitsec/biscuit/error/Error.java b/src/main/java/org/biscuitsec/biscuit/error/Error.java index f1abdda1..d9eb8d48 100644 --- a/src/main/java/org/biscuitsec/biscuit/error/Error.java +++ b/src/main/java/org/biscuitsec/biscuit/error/Error.java @@ -1,632 +1,740 @@ package org.biscuitsec.biscuit.error; -import org.biscuitsec.biscuit.datalog.expressions.Expression; import com.google.gson.JsonElement; import com.google.gson.JsonObject; import com.google.gson.JsonPrimitive; import io.vavr.control.Option; - import java.util.List; import java.util.Objects; +import org.biscuitsec.biscuit.datalog.expressions.Expression; public class Error extends Exception { - public Option> failed_checks() { - return Option.none(); + public Option> getFailedChecks() { + return Option.none(); + } + + /** + * Serialize error to JSON + * + * @return json object + */ + public JsonElement toJson() { + return new JsonObject(); + } + + public static final class InternalError extends Error {} + + public static class FormatError extends Error { + + private static JsonElement jsonWrapper(JsonElement e) { + JsonObject root = new JsonObject(); + root.add("Format", e); + return root; } - public JsonElement toJson() { - return new JsonObject(); - } - - public static class InternalError extends Error {} - - public static class FormatError extends Error { + public static class Signature extends FormatError { + private static JsonElement jsonWrapper(JsonElement e) { + JsonObject signature = new JsonObject(); + signature.add("Signature", e); + return FormatError.jsonWrapper(signature); + } - private static JsonElement jsonWrapper(JsonElement e) { - JsonObject root = new JsonObject(); - root.add("Format", e); - return root; - } - public static class Signature extends FormatError { - private static JsonElement jsonWrapper(JsonElement e) { - JsonObject signature = new JsonObject(); - signature.add("Signature", e); - return FormatError.jsonWrapper(signature); - } - public static class InvalidFormat extends Signature { - public InvalidFormat() {} - @Override - public boolean equals(Object o) { - if (this == o) return true; - return o != null && getClass() == o.getClass(); - } - - @Override - public JsonElement toJson() { - return Signature.jsonWrapper(new JsonPrimitive("InvalidFormat")); - } - @Override - public String toString(){ - return "Err(Format(Signature(InvalidFormat)))"; - } - } - public static class InvalidSignature extends Signature { - final public String e; - public InvalidSignature(String e) { - this.e = e; - } - @Override - public boolean equals(Object o) { - if (this == o) return true; - return o != null && getClass() == o.getClass(); - } - @Override - public JsonElement toJson() { - JsonObject jo = new JsonObject(); - jo.addProperty("InvalidSignature", this.e); - return Signature.jsonWrapper(jo); - } - @Override - public String toString(){ - return "Err(Format(Signature(InvalidFormat(\""+this.e+"\"))))"; - } - } - } - - public static class SealedSignature extends FormatError { - @Override - public boolean equals(Object o) { - if (this == o) return true; - return o != null && getClass() == o.getClass(); - } - @Override - public JsonElement toJson() { - return FormatError.jsonWrapper(new JsonPrimitive("SealedSignature")); - } - @Override - public String toString(){ - return "Err(Format(SealedSignature))"; - } - } - public static class EmptyKeys extends FormatError { - @Override - public boolean equals(Object o) { - if (this == o) return true; - return o != null && getClass() == o.getClass(); - } - @Override - public JsonElement toJson() { - return FormatError.jsonWrapper(new JsonPrimitive("EmptyKeys")); - } - @Override - public String toString(){ - return "Err(Format(EmptyKeys))"; - } - } - public static class UnknownPublicKey extends FormatError { - @Override - public boolean equals(Object o) { - if (this == o) return true; - return o != null && getClass() == o.getClass(); - } - @Override - public JsonElement toJson() { - return FormatError.jsonWrapper(new JsonPrimitive("UnknownPublicKey")); - } - @Override - public String toString(){ - return "Err(Format(UnknownPublicKey))"; - } - } - public static class DeserializationError extends FormatError { - final public String e; - - public DeserializationError(String e) { - this.e = e; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - DeserializationError other = (DeserializationError) o; - return e.equals(other.e); - } - - @Override - public int hashCode() { - return Objects.hash(e); - } - - @Override - public String toString(){ - return "Err(Format(DeserializationError(\""+this.e+"\"))"; - } - - @Override - public JsonElement toJson() { - JsonObject jo = new JsonObject(); - jo.addProperty("DeserializationError", this.e); - return FormatError.jsonWrapper(jo); - } + public static final class InvalidFormat extends Signature { + public InvalidFormat() {} + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + return o != null && getClass() == o.getClass(); } - public static class SerializationError extends FormatError { - final public String e; - - public SerializationError(String e) { - this.e = e; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - SerializationError other = (SerializationError) o; - return e.equals(other.e); - } - - @Override - public int hashCode() { - return Objects.hash(e); - } - - @Override - public String toString(){ - return "Err(Format(SerializationError(\""+this.e+"\"))"; - } - - @Override - public JsonElement toJson() { - JsonObject jo = new JsonObject(); - jo.addProperty("SerializationError", this.e); - return FormatError.jsonWrapper(jo); - } - } - public static class BlockDeserializationError extends FormatError { - final public String e; - - public BlockDeserializationError(String e) { - this.e = e; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - BlockDeserializationError other = (BlockDeserializationError) o; - return e.equals(other.e); - } - - @Override - public int hashCode() { - return Objects.hash(e); - } - - @Override - public String toString() { - return "Err(FormatError.BlockDeserializationError{ error: "+ e + " }"; - } - - @Override - public JsonElement toJson() { - JsonObject jo = new JsonObject(); - jo.addProperty("BlockDeserializationError", this.e); - return FormatError.jsonWrapper(jo); - } - } - public static class BlockSerializationError extends FormatError { - final public String e; - - public BlockSerializationError(String e) { - this.e = e; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - BlockSerializationError other = (BlockSerializationError) o; - return e.equals(other.e); - } - - @Override - public int hashCode() { - return Objects.hash(e); - } - - @Override - public String toString() { - return "Err(FormatError.BlockSerializationError{ error: "+ e + " }"; - } - - @Override - public JsonElement toJson() { - JsonObject jo = new JsonObject(); - jo.addProperty("BlockSerializationError", this.e); - return FormatError.jsonWrapper(jo); - } + @Override + public JsonElement toJson() { + return Signature.jsonWrapper(new JsonPrimitive("InvalidFormat")); } - public static class Version extends FormatError { - final public int minimum; - final public int maximum; - final public int actual; - - public Version(int minimum, int maximum, int actual) { - this.minimum = minimum; - this.maximum = maximum; - this.actual = actual; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - Version version = (Version) o; - - if (minimum != version.minimum) return false; - if (maximum != version.maximum) return false; - return actual == version.actual; - } - - @Override - public int hashCode() { - return super.hashCode(); - } - - @Override - public String toString() { - return "Version{" + - "minimum=" + minimum + - ", maximum=" + maximum + - ", actual=" + actual + - '}'; - } - @Override - public JsonElement toJson() { - JsonObject child = new JsonObject(); - child.addProperty("minimum",this.minimum); - child.addProperty("maximum",this.maximum); - child.addProperty("actual", this.actual); - JsonObject jo = new JsonObject(); - jo.add("Version", child); - return FormatError.jsonWrapper(jo); - } + @Override + public String toString() { + return "Err(Format(Signature(InvalidFormat)))"; } + } - public static class InvalidSignatureSize extends FormatError { - final public int size; - - public InvalidSignatureSize(int size) { - this.size = size; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - InvalidSignatureSize iss = (InvalidSignatureSize) o; - - return size == iss.size; - } - - @Override - public int hashCode() { - return Objects.hash(size); - } - - @Override - public String toString() { - return "InvalidSignatureSize{" + - "size=" + size + - '}'; - } - @Override - public JsonElement toJson() { - JsonObject jo = new JsonObject(); - jo.add("InvalidSignatureSize", new JsonPrimitive(size)); - return FormatError.jsonWrapper(jo); - } - } - } - public static class InvalidAuthorityIndex extends Error { - final public long index; + public static final class InvalidSignature extends Signature { + private final String err; - public InvalidAuthorityIndex(long index) { - this.index = index; + public InvalidSignature(String e) { + this.err = e; } @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InvalidAuthorityIndex other = (InvalidAuthorityIndex) o; - return index == other.index; + if (this == o) { + return true; + } + return o != null && getClass() == o.getClass(); } @Override - public int hashCode() { - return Objects.hash(index); + public JsonElement toJson() { + JsonObject jo = new JsonObject(); + jo.addProperty("InvalidSignature", this.err); + return Signature.jsonWrapper(jo); } @Override public String toString() { - return "Err(InvalidAuthorityIndex{ index: "+ index + " }"; + return "Err(Format(Signature(InvalidFormat(\"" + this.err + "\"))))"; } + } + } - @Override - public JsonElement toJson() { - JsonObject child = new JsonObject(); - child.addProperty("index",this.index); - JsonObject jo = new JsonObject(); - jo.add("InvalidAuthorityIndex", child); - return jo; + public static final class SealedSignature extends FormatError { + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } + return o != null && getClass() == o.getClass(); + } + + @Override + public JsonElement toJson() { + return FormatError.jsonWrapper(new JsonPrimitive("SealedSignature")); + } + + @Override + public String toString() { + return "Err(Format(SealedSignature))"; + } } - public static class InvalidBlockIndex extends Error { - final public long expected; - final public long found; - public InvalidBlockIndex(long expected, long found) { - this.expected = expected; - this.found = found; + public static final class EmptyKeys extends FormatError { + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } + return o != null && getClass() == o.getClass(); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InvalidBlockIndex other = (InvalidBlockIndex) o; - return expected == other.expected && found == other.found; - } + @Override + public JsonElement toJson() { + return FormatError.jsonWrapper(new JsonPrimitive("EmptyKeys")); + } - @Override - public int hashCode() { - return Objects.hash(expected, found); - } + @Override + public String toString() { + return "Err(Format(EmptyKeys))"; + } + } - @Override - public String toString() { - return "Err(InvalidBlockIndex{ expected: " + expected + ", found: " + found + " }"; + public static final class UnknownPublicKey extends FormatError { + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } + return o != null && getClass() == o.getClass(); + } - @Override - public JsonElement toJson() { - JsonObject child = new JsonObject(); - child.addProperty("expected",this.expected); - child.addProperty("fount", this.found); - JsonObject jo = new JsonObject(); - jo.add("InvalidBlockIndex", child); - return jo; - } + @Override + public JsonElement toJson() { + return FormatError.jsonWrapper(new JsonPrimitive("UnknownPublicKey")); + } + + @Override + public String toString() { + return "Err(Format(UnknownPublicKey))"; + } } - public static class SymbolTableOverlap extends Error { - @Override - public boolean equals(Object o) { - if (this == o) return true; - return o != null && getClass() == o.getClass(); - } - @Override - public JsonElement toJson(){ - return new JsonPrimitive("SymbolTableOverlap"); - } + public static final class DeserializationError extends FormatError { + private final String err; + + public DeserializationError(String e) { + this.err = e; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DeserializationError other = (DeserializationError) o; + return err.equals(other.err); + } + + @Override + public int hashCode() { + return Objects.hash(err); + } + + @Override + public String toString() { + return "Err(Format(DeserializationError(\"" + this.err + "\"))"; + } + + @Override + public JsonElement toJson() { + JsonObject jo = new JsonObject(); + jo.addProperty("DeserializationError", this.err); + return FormatError.jsonWrapper(jo); + } } - public static class MissingSymbols extends Error { - @Override - public boolean equals(Object o) { - if (this == o) return true; - return o != null && getClass() == o.getClass(); - } - @Override - public JsonElement toJson(){ - return new JsonPrimitive("MissingSymbols"); - } + + public static final class SerializationError extends FormatError { + private final String err; + + public SerializationError(String e) { + this.err = e; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SerializationError other = (SerializationError) o; + return err.equals(other.err); + } + + @Override + public int hashCode() { + return Objects.hash(err); + } + + @Override + public String toString() { + return "Err(Format(SerializationError(\"" + this.err + "\"))"; + } + + @Override + public JsonElement toJson() { + JsonObject jo = new JsonObject(); + jo.addProperty("SerializationError", this.err); + return FormatError.jsonWrapper(jo); + } } - public static class Sealed extends Error { - @Override - public boolean equals(Object o) { - if (this == o) return true; - return o != null && getClass() == o.getClass(); + + public static final class BlockDeserializationError extends FormatError { + private final String err; + + public BlockDeserializationError(String e) { + this.err = e; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + BlockDeserializationError other = (BlockDeserializationError) o; + return err.equals(other.err); + } + + @Override + public int hashCode() { + return Objects.hash(err); + } + + @Override + public String toString() { + return "Err(FormatError.BlockDeserializationError{ error: " + err + " }"; + } + + @Override + public JsonElement toJson() { + JsonObject jo = new JsonObject(); + jo.addProperty("BlockDeserializationError", this.err); + return FormatError.jsonWrapper(jo); + } + } + + public static final class BlockSerializationError extends FormatError { + private final String err; + + public BlockSerializationError(String e) { + this.err = e; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + BlockSerializationError other = (BlockSerializationError) o; + return err.equals(other.err); + } + + @Override + public int hashCode() { + return Objects.hash(err); + } + + @Override + public String toString() { + return "Err(FormatError.BlockSerializationError{ error: " + err + " }"; + } + + @Override + public JsonElement toJson() { + JsonObject jo = new JsonObject(); + jo.addProperty("BlockSerializationError", this.err); + return FormatError.jsonWrapper(jo); + } + } + + public static final class Version extends FormatError { + private final int minimum; + private final int maximum; + private final int actual; + + public Version(int minimum, int maximum, int actual) { + this.minimum = minimum; + this.maximum = maximum; + this.actual = actual; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + Version version = (Version) o; + + if (minimum != version.minimum) { + return false; + } + if (maximum != version.maximum) { + return false; + } + return actual == version.actual; + } + + @Override + public int hashCode() { + return super.hashCode(); + } + + @Override + public String toString() { + return "Version{" + + "minimum=" + + minimum + + ", maximum=" + + maximum + + ", actual=" + + actual + + '}'; + } + + @Override + public JsonElement toJson() { + JsonObject child = new JsonObject(); + child.addProperty("minimum", this.minimum); + child.addProperty("maximum", this.maximum); + child.addProperty("actual", this.actual); + JsonObject jo = new JsonObject(); + jo.add("Version", child); + return FormatError.jsonWrapper(jo); + } + } + + public static final class InvalidSignatureSize extends FormatError { + private final int size; + + public InvalidSignatureSize(int size) { + this.size = size; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - @Override - public JsonElement toJson(){ - return new JsonPrimitive("Sealed"); + if (o == null || getClass() != o.getClass()) { + return false; } + + InvalidSignatureSize iss = (InvalidSignatureSize) o; + + return size == iss.size; + } + + @Override + public int hashCode() { + return Objects.hash(size); + } + + @Override + public String toString() { + return "InvalidSignatureSize{" + "size=" + size + '}'; + } + + @Override + public JsonElement toJson() { + JsonObject jo = new JsonObject(); + jo.add("InvalidSignatureSize", new JsonPrimitive(size)); + return FormatError.jsonWrapper(jo); + } } - public static class FailedLogic extends Error { - final public LogicError error; + } - public FailedLogic(LogicError error) { - this.error = error; - } + public static final class InvalidAuthorityIndex extends Error { + public final long index; - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - FailedLogic other = (FailedLogic) o; - return error.equals(other.error); - } + public InvalidAuthorityIndex(long index) { + this.index = index; + } - @Override - public int hashCode() { - return Objects.hash(error); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + InvalidAuthorityIndex other = (InvalidAuthorityIndex) o; + return index == other.index; + } - @Override - public String toString() { - return "Err(FailedLogic("+ error +"))"; - } + @Override + public int hashCode() { + return Objects.hash(index); + } - @Override - public Option> failed_checks() { - return this.error.failed_checks(); - } + @Override + public String toString() { + return "Err(InvalidAuthorityIndex{ index: " + index + " }"; + } - @Override - public JsonElement toJson(){ - JsonObject jo = new JsonObject(); - jo.add("FailedLogic", this.error.toJson()); - return jo; - } + @Override + public JsonElement toJson() { + JsonObject child = new JsonObject(); + child.addProperty("index", this.index); + JsonObject jo = new JsonObject(); + jo.add("InvalidAuthorityIndex", child); + return jo; + } + } + + public static final class InvalidBlockIndex extends Error { + public final long expected; + public final long found; + public InvalidBlockIndex(long expected, long found) { + this.expected = expected; + this.found = found; } - public static class Language extends Error { - final public FailedCheck.LanguageError langError; - public Language(FailedCheck.LanguageError langError){ - this.langError = langError; - } - @Override - public boolean equals(Object o) { - if (this == o) return true; - return o != null && getClass() == o.getClass(); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + InvalidBlockIndex other = (InvalidBlockIndex) o; + return expected == other.expected && found == other.found; + } - @Override - public JsonElement toJson(){ - JsonObject jo = new JsonObject(); - jo.add("Language", langError.toJson()); - return jo; - } + @Override + public int hashCode() { + return Objects.hash(expected, found); } - public static class TooManyFacts extends Error { - @Override - public boolean equals(Object o) { - if (this == o) return true; - return o != null && getClass() == o.getClass(); - } - @Override - public JsonElement toJson(){ - return new JsonPrimitive("TooManyFacts"); - } + @Override + public String toString() { + return "Err(InvalidBlockIndex{ expected: " + expected + ", found: " + found + " }"; } - public static class TooManyIterations extends Error { - @Override - public boolean equals(Object o) { - if (this == o) return true; - return o != null && getClass() == o.getClass(); - } - @Override - public JsonElement toJson(){ - return new JsonPrimitive("TooManyIterations"); - } + @Override + public JsonElement toJson() { + JsonObject child = new JsonObject(); + child.addProperty("expected", this.expected); + child.addProperty("fount", this.found); + JsonObject jo = new JsonObject(); + jo.add("InvalidBlockIndex", child); + return jo; + } + } + + public static final class SymbolTableOverlap extends Error { + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + return o != null && getClass() == o.getClass(); } - public static class Timeout extends Error { - @Override - public boolean equals(Object o) { - if (this == o) return true; - return o != null && getClass() == o.getClass(); - } - @Override - public JsonElement toJson(){ - return new JsonPrimitive("Timeout"); - } + @Override + public JsonElement toJson() { + return new JsonPrimitive("SymbolTableOverlap"); + } + } + + public static final class MissingSymbols extends Error { + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + return o != null && getClass() == o.getClass(); } - public static class Execution extends Error { - public enum Kind { - Execution, - Overflow + @Override + public JsonElement toJson() { + return new JsonPrimitive("MissingSymbols"); + } + } + + public static final class Sealed extends Error { + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + return o != null && getClass() == o.getClass(); + } - } - Expression e; - String message; + @Override + public JsonElement toJson() { + return new JsonPrimitive("Sealed"); + } + } - Kind kind; + public static final class FailedLogic extends Error { + public final LogicError error; - public Execution(Expression ex, String msg) { - e = ex; - message = msg; - kind = Kind.Execution; - } + public FailedLogic(LogicError error) { + this.error = error; + } - public Execution( String msg) { - e = null; - message = msg; - kind = Kind.Execution; - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FailedLogic other = (FailedLogic) o; + return error.equals(other.error); + } - public Execution(Kind kind, String msg) { - e = null; - this.kind = kind; - message = msg; - } + @Override + public int hashCode() { + return Objects.hash(error); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - return o != null && getClass() == o.getClass(); - } - @Override - public JsonElement toJson(){ - JsonObject jo = new JsonObject(); - jo.add("Execution", new JsonPrimitive(this.kind.toString())); - return jo; + @Override + public String toString() { + return "Err(FailedLogic(" + error + "))"; + } - } + @Override + public Option> getFailedChecks() { + return this.error.getFailedChecks(); + } - @Override - public String toString() { - return "Execution error when evaluating expression '" + e + - "': " + message; - } + @Override + public JsonElement toJson() { + JsonObject jo = new JsonObject(); + jo.add("FailedLogic", this.error.toJson()); + return jo; } + } - public static class InvalidType extends Error { - @Override - public boolean equals(Object o) { - if (this == o) return true; - return o != null && getClass() == o.getClass(); - } - @Override - public JsonElement toJson(){ - return new JsonPrimitive("InvalidType"); - } + public static final class Language extends Error { + public final FailedCheck.LanguageError langError; + + public Language(FailedCheck.LanguageError langError) { + this.langError = langError; } - public static class Parser extends Error { - final public org.biscuitsec.biscuit.token.builder.parser.Error error; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + return o != null && getClass() == o.getClass(); + } - public Parser(org.biscuitsec.biscuit.token.builder.parser.Error error) { - this.error = error; - } + @Override + public JsonElement toJson() { + JsonObject jo = new JsonObject(); + jo.add("Language", langError.toJson()); + return jo; + } + } + + public static final class TooManyFacts extends Error { + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + return o != null && getClass() == o.getClass(); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + @Override + public JsonElement toJson() { + return new JsonPrimitive("TooManyFacts"); + } + } + + public static final class TooManyIterations extends Error { + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + return o != null && getClass() == o.getClass(); + } - Parser parser = (Parser) o; + @Override + public JsonElement toJson() { + return new JsonPrimitive("TooManyIterations"); + } + } + + public static final class Timeout extends Error { + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + return o != null && getClass() == o.getClass(); + } - return error.equals(parser.error); - } + @Override + public JsonElement toJson() { + return new JsonPrimitive("Timeout"); + } + } - @Override - public int hashCode() { - return error.hashCode(); - } + public static final class Execution extends Error { + public enum Kind { + Execution, + Overflow + } - @Override - public String toString() { - return "Parser{" + - "error=" + error + - '}'; - } + Expression expr; + String message; - @Override - public JsonElement toJson(){ - JsonObject error = new JsonObject(); - error.add("error", this.error.toJson()); - return error; - } + Kind kind; + + public Execution(Expression ex, String msg) { + expr = ex; + message = msg; + kind = Kind.Execution; + } + + public Execution(String msg) { + expr = null; + message = msg; + kind = Kind.Execution; + } + + public Execution(Kind kind, String msg) { + expr = null; + this.kind = kind; + message = msg; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + return o != null && getClass() == o.getClass(); + } + + @Override + public JsonElement toJson() { + JsonObject jo = new JsonObject(); + jo.add("Execution", new JsonPrimitive(this.kind.toString())); + return jo; + } + + @Override + public String toString() { + return "Execution error when evaluating expression '" + expr + "': " + message; + } + } + + public static final class InvalidType extends Error { + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + return o != null && getClass() == o.getClass(); + } + + @Override + public JsonElement toJson() { + return new JsonPrimitive("InvalidType"); + } + } + + public static final class Parser extends Error { + public final org.biscuitsec.biscuit.token.builder.parser.Error error; + + public Parser(org.biscuitsec.biscuit.token.builder.parser.Error error) { + this.error = error; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + Parser parser = (Parser) o; + + return error.equals(parser.error); + } + + @Override + public int hashCode() { + return error.hashCode(); + } + + @Override + public String toString() { + return "Parser{error=" + error + '}'; + } + + @Override + public JsonElement toJson() { + JsonObject error = new JsonObject(); + error.add("error", this.error.toJson()); + return error; } + } } diff --git a/src/main/java/org/biscuitsec/biscuit/error/FailedCheck.java b/src/main/java/org/biscuitsec/biscuit/error/FailedCheck.java index 8aa17133..ec4156de 100644 --- a/src/main/java/org/biscuitsec/biscuit/error/FailedCheck.java +++ b/src/main/java/org/biscuitsec/biscuit/error/FailedCheck.java @@ -1,172 +1,199 @@ package org.biscuitsec.biscuit.error; -import com.google.gson.*; - +import com.google.gson.Gson; +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonPrimitive; import java.util.List; import java.util.Objects; public class FailedCheck { - public JsonElement toJson(){ return new JsonObject();} - - public static class FailedBlock extends FailedCheck { - final public long block_id; - final public long check_id; - final public String rule; - - public FailedBlock(long block_id, long check_id, String rule) { - this.block_id = block_id; - this.check_id = check_id; - this.rule = rule; - } + /** + * serialize to Json Object + * + * @return json object + */ + public JsonElement toJson() { + return new JsonObject(); + } + + public static final class FailedBlock extends FailedCheck { + public final long blockId; + public final long checkId; + public final String rule; + + public FailedBlock(long blockId, long checkId, String rule) { + this.blockId = blockId; + this.checkId = checkId; + this.rule = rule; + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - FailedBlock b = (FailedBlock) o; - return block_id == b.block_id && check_id == b.check_id && rule.equals(b.rule); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FailedBlock b = (FailedBlock) o; + return blockId == b.blockId && checkId == b.checkId && rule.equals(b.rule); + } - @Override - public int hashCode() { - return Objects.hash(block_id, check_id, rule); - } + @Override + public int hashCode() { + return Objects.hash(blockId, checkId, rule); + } - @Override - public String toString() { - return "Block(FailedBlockCheck " + new Gson().toJson(toJson())+")"; - } + @Override + public String toString() { + return "Block(FailedBlockCheck " + new Gson().toJson(toJson()) + ")"; + } - @Override - public JsonElement toJson() { - JsonObject jo = new JsonObject(); - jo.addProperty("block_id", block_id); - jo.addProperty("check_id", check_id); - jo.addProperty("rule", rule); - JsonObject block = new JsonObject(); - block.add("Block", jo); - return block; - } + @Override + public JsonElement toJson() { + JsonObject jo = new JsonObject(); + jo.addProperty("block_id", blockId); + jo.addProperty("check_id", checkId); + jo.addProperty("rule", rule); + JsonObject block = new JsonObject(); + block.add("Block", jo); + return block; } + } - public static class FailedAuthorizer extends FailedCheck { - final public long check_id; - final public String rule; + public static final class FailedAuthorizer extends FailedCheck { + public final long checkId; + public final String rule; - public FailedAuthorizer(long check_id, String rule) { - this.check_id = check_id; - this.rule = rule; - } + public FailedAuthorizer(long checkId, String rule) { + this.checkId = checkId; + this.rule = rule; + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - FailedAuthorizer b = (FailedAuthorizer) o; - return check_id == b.check_id && rule.equals(b.rule); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FailedAuthorizer b = (FailedAuthorizer) o; + return checkId == b.checkId && rule.equals(b.rule); + } - @Override - public int hashCode() { - return Objects.hash(check_id, rule); - } + @Override + public int hashCode() { + return Objects.hash(checkId, rule); + } - @Override - public String toString() { - return "FailedCaveat.FailedAuthorizer { check_id: "+check_id+ - ", rule: "+rule+" }"; - } + @Override + public String toString() { + return "FailedCaveat.FailedAuthorizer { check_id: " + checkId + ", rule: " + rule + " }"; + } - @Override - public JsonElement toJson() { - JsonObject jo = new JsonObject(); - jo.addProperty("check_id", check_id); - jo.addProperty("rule", rule); - JsonObject authorizer = new JsonObject(); - authorizer.add("Authorizer", jo); - return authorizer; - } + @Override + public JsonElement toJson() { + JsonObject jo = new JsonObject(); + jo.addProperty("check_id", checkId); + jo.addProperty("rule", rule); + JsonObject authorizer = new JsonObject(); + authorizer.add("Authorizer", jo); + return authorizer; } + } - public static class ParseErrors extends FailedCheck { + public static final class ParseErrors extends FailedCheck {} + public static class LanguageError extends FailedCheck { + public static final class ParseError extends LanguageError { + + @Override + public JsonElement toJson() { + return new JsonPrimitive("ParseError"); + } } - public static class LanguageError extends FailedCheck { - public static class ParseError extends LanguageError { + public static final class Builder extends LanguageError { + List invalidVariables; + + public Builder(List invalidVariables) { + this.invalidVariables = invalidVariables; + } - @Override - public JsonElement toJson() { - return new JsonPrimitive("ParseError"); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - public static class Builder extends LanguageError { - List invalid_variables; - public Builder(List invalid_variables) { - this.invalid_variables = invalid_variables; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Builder b = (Builder) o; - return invalid_variables == b.invalid_variables && invalid_variables.equals(b.invalid_variables); - } - - @Override - public int hashCode() { - return Objects.hash(invalid_variables); - } - - @Override - public String toString() { - return "InvalidVariables { message: "+invalid_variables+" }"; - } - - @Override - public JsonElement toJson() { - JsonObject authorizer = new JsonObject(); - JsonArray ja = new JsonArray(); - for(String s : invalid_variables){ - ja.add(s); - } - authorizer.add("InvalidVariables", ja); - return authorizer; - } + if (o == null || getClass() != o.getClass()) { + return false; } + Builder b = (Builder) o; + return invalidVariables == b.invalidVariables + && invalidVariables.equals(b.invalidVariables); + } + + @Override + public int hashCode() { + return Objects.hash(invalidVariables); + } + + @Override + public String toString() { + return "InvalidVariables { message: " + invalidVariables + " }"; + } + + @Override + public JsonElement toJson() { + JsonObject authorizer = new JsonObject(); + JsonArray ja = new JsonArray(); + for (String s : invalidVariables) { + ja.add(s); + } + authorizer.add("InvalidVariables", ja); + return authorizer; + } + } - public static class UnknownVariable extends LanguageError { - String message; - public UnknownVariable(String message) { - this.message = message; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - UnknownVariable b = (UnknownVariable) o; - return this.message == b.message && message.equals(b.message); - } - - @Override - public int hashCode() { - return Objects.hash(message); - } - - @Override - public String toString() { - return "LanguageError.UnknownVariable { message: "+message+ " }"; - } - - @Override - public JsonElement toJson() { - JsonObject authorizer = new JsonObject(); - authorizer.add("UnknownVariable", new JsonPrimitive(message)); - return authorizer; - } + public static final class UnknownVariable extends LanguageError { + String message; + + public UnknownVariable(String message) { + this.message = message; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; } + UnknownVariable b = (UnknownVariable) o; + return this.message == b.message && message.equals(b.message); + } + + @Override + public int hashCode() { + return Objects.hash(message); + } + + @Override + public String toString() { + return "LanguageError.UnknownVariable { message: " + message + " }"; + } + + @Override + public JsonElement toJson() { + JsonObject authorizer = new JsonObject(); + authorizer.add("UnknownVariable", new JsonPrimitive(message)); + return authorizer; + } } -} \ No newline at end of file + } +} diff --git a/src/main/java/org/biscuitsec/biscuit/error/LogicError.java b/src/main/java/org/biscuitsec/biscuit/error/LogicError.java index 6aba1978..725bd39d 100644 --- a/src/main/java/org/biscuitsec/biscuit/error/LogicError.java +++ b/src/main/java/org/biscuitsec/biscuit/error/LogicError.java @@ -1,329 +1,352 @@ package org.biscuitsec.biscuit.error; +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonPrimitive; +import io.vavr.control.Option; import java.util.List; import java.util.Objects; -import com.google.gson.*; -import io.vavr.control.Option; - public class LogicError { - public Option> failed_checks() { - return Option.none(); - } - public JsonElement toJson() { - return new JsonObject(); - } + public Option> getFailedChecks() { + return Option.none(); + } - public static class InvalidAuthorityFact extends LogicError { - final public String e; + public JsonElement toJson() { + return new JsonObject(); + } - public InvalidAuthorityFact(String e) { - this.e = e; - } + public static final class InvalidAuthorityFact extends LogicError { + public final String err; - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InvalidAuthorityFact other = (InvalidAuthorityFact) o; - return e.equals(other.e); - } + public InvalidAuthorityFact(String e) { + this.err = e; + } - @Override - public int hashCode() { - return Objects.hash(e); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + InvalidAuthorityFact other = (InvalidAuthorityFact) o; + return err.equals(other.err); + } - @Override - public String toString() { - return "LogicError.InvalidAuthorityFact{ error: "+ e + " }"; - } + @Override + public int hashCode() { + return Objects.hash(err); + } - @Override - public JsonElement toJson() { - return new JsonPrimitive("InvalidAuthorityFact"); - } + @Override + public String toString() { + return "LogicError.InvalidAuthorityFact{ error: " + err + " }"; + } + @Override + public JsonElement toJson() { + return new JsonPrimitive("InvalidAuthorityFact"); } + } - public static class InvalidAmbientFact extends LogicError { - final public String e; + public static final class InvalidAmbientFact extends LogicError { + public final String err; - public InvalidAmbientFact(String e) { - this.e = e; - } + public InvalidAmbientFact(String e) { + this.err = e; + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InvalidAmbientFact other = (InvalidAmbientFact) o; - return e.equals(other.e); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + InvalidAmbientFact other = (InvalidAmbientFact) o; + return err.equals(other.err); + } - @Override - public int hashCode() { - return Objects.hash(e); - } + @Override + public int hashCode() { + return Objects.hash(err); + } - @Override - public String toString() { - return "LogicError.InvalidAmbientFact{ error: "+ e + " }"; - } + @Override + public String toString() { + return "LogicError.InvalidAmbientFact{ error: " + err + " }"; + } - @Override - public JsonElement toJson() { - JsonObject child = new JsonObject(); - child.addProperty("error", this.e); - JsonObject root = new JsonObject(); - root.add("InvalidAmbientFact", child); - return root; - } + @Override + public JsonElement toJson() { + JsonObject child = new JsonObject(); + child.addProperty("error", this.err); + JsonObject root = new JsonObject(); + root.add("InvalidAmbientFact", child); + return root; } + } - public static class InvalidBlockFact extends LogicError { - final public long id; - final public String e; + public static final class InvalidBlockFact extends LogicError { + public final long id; + public final String err; - public InvalidBlockFact(long id, String e) { - this.id = id; - this.e = e; - } + public InvalidBlockFact(long id, String e) { + this.id = id; + this.err = e; + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InvalidBlockFact other = (InvalidBlockFact) o; - return id == other.id && e.equals(other.e); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + InvalidBlockFact other = (InvalidBlockFact) o; + return id == other.id && err.equals(other.err); + } - @Override - public int hashCode() { - return Objects.hash(id, e); - } + @Override + public int hashCode() { + return Objects.hash(id, err); + } - @Override - public String toString() { - return "LogicError.InvalidBlockFact{ id: "+id+", error: "+ e + " }"; - } + @Override + public String toString() { + return "LogicError.InvalidBlockFact{ id: " + id + ", error: " + err + " }"; + } - @Override - public JsonElement toJson() { - JsonObject child = new JsonObject(); - child.addProperty("id",this.id); - child.addProperty("error", this.e); - JsonObject root = new JsonObject(); - root.add("InvalidBlockFact", child); - return root; - } + @Override + public JsonElement toJson() { + JsonObject child = new JsonObject(); + child.addProperty("id", this.id); + child.addProperty("error", this.err); + JsonObject root = new JsonObject(); + root.add("InvalidBlockFact", child); + return root; + } + } + public static final class InvalidBlockRule extends LogicError { + public final long id; + public final String err; + public InvalidBlockRule(long id, String e) { + this.id = id; + this.err = e; } - public static class InvalidBlockRule extends LogicError { - final public long id; - final public String e; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + InvalidBlockRule other = (InvalidBlockRule) o; + return id == other.id && err.equals(other.err); + } - public InvalidBlockRule(long id, String e) { - this.id = id; - this.e = e; - } + @Override + public int hashCode() { + return Objects.hash(id, err); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InvalidBlockRule other = (InvalidBlockRule) o; - return id == other.id && e.equals(other.e); - } + @Override + public String toString() { + return "LogicError.InvalidBlockRule{ id: " + id + ", error: " + err + " }"; + } - @Override - public int hashCode() { - return Objects.hash(id, e); - } + @Override + public JsonElement toJson() { + JsonArray child = new JsonArray(); + child.add(this.id); + child.add(this.err); + JsonObject root = new JsonObject(); + root.add("InvalidBlockRule", child); + return root; + } + } - @Override - public String toString() { - return "LogicError.InvalidBlockRule{ id: "+id+", error: "+ e + " }"; - } + public static final class Unauthorized extends LogicError { + public final List errors; + public final MatchedPolicy policy; - @Override - public JsonElement toJson() { - JsonArray child = new JsonArray(); - child.add(this.id); - child.add(this.e); - JsonObject root = new JsonObject(); - root.add("InvalidBlockRule", child); - return root; - } + public Unauthorized(MatchedPolicy policy, List errors) { + this.errors = errors; + this.policy = policy; + } + public Option> getFailedChecks() { + return Option.some(errors); + } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Unauthorized other = (Unauthorized) o; + if (errors.size() != other.errors.size()) { + return false; + } + for (int i = 0; i < errors.size(); i++) { + if (!errors.get(i).equals(other.errors.get(i))) { + return false; + } + } + return true; } - public static class Unauthorized extends LogicError { - final public List errors; - final public MatchedPolicy policy; + @Override + public int hashCode() { + return Objects.hash(errors); + } - public Unauthorized(MatchedPolicy policy, List errors) { - this.errors = errors; - this.policy = policy; - } + @Override + public String toString() { + return "Unauthorized(policy = " + policy + " errors = " + errors + ")"; + } - public Option> failed_checks() { - return Option.some(errors); - } + @Override + public JsonElement toJson() { + JsonObject unauthorized = new JsonObject(); + unauthorized.add("policy", this.policy.toJson()); + JsonArray ja = new JsonArray(); + for (FailedCheck t : this.errors) { + ja.add(t.toJson()); + } + unauthorized.add("checks", ja); + JsonObject jo = new JsonObject(); + jo.add("Unauthorized", unauthorized); + return jo; + } + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Unauthorized other = (Unauthorized) o; - if(errors.size() != other.errors.size()) { - return false; - } - for(int i = 0; i < errors.size(); i++) { - if(!errors.get(i).equals(other.errors.get(i))) { - return false; - } - } - return true; - } + public static final class NoMatchingPolicy extends LogicError { + public final List errors; - @Override - public int hashCode() { - return Objects.hash(errors); - } + public NoMatchingPolicy(List errors) { + this.errors = errors; + } - @Override - public String toString() { - return "Unauthorized( policy = "+policy+ " errors = " + errors +")"; - } + @Override + public int hashCode() { + return Objects.hash(errors); + } - @Override - public JsonElement toJson() { - JsonObject jo = new JsonObject(); - JsonObject unauthorized = new JsonObject(); - unauthorized.add("policy", this.policy.toJson()); - JsonArray ja = new JsonArray(); - for (FailedCheck t: this.errors) { - ja.add(t.toJson()); - } - unauthorized.add("checks", ja); - jo.add("Unauthorized", unauthorized); - return jo; - } + @Override + public Option> getFailedChecks() { + return Option.some(errors); } - public static class NoMatchingPolicy extends LogicError { - final public List errors; - public NoMatchingPolicy(List errors) { - this.errors = errors; - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Unauthorized other = (Unauthorized) o; + if (errors.size() != other.errors.size()) { + return false; + } + for (int i = 0; i < errors.size(); i++) { + if (!errors.get(i).equals(other.errors.get(i))) { + return false; + } + } + return true; + } - @Override - public int hashCode() { - return Objects.hash(errors); - } + @SuppressWarnings("checkstyle:RegexpSinglelineJava") + @Override + public String toString() { + return "NoMatchingPolicy{ }"; + } - @Override - public Option> failed_checks() { - return Option.some(errors); - } + @Override + public JsonElement toJson() { + JsonObject jo = new JsonObject(); + JsonArray ja = new JsonArray(); + for (FailedCheck t : this.errors) { + ja.add(t.toJson()); + } + jo.add("NoMatchingPolicy", ja); + return jo; + } + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Unauthorized other = (Unauthorized) o; - if(errors.size() != other.errors.size()) { - return false; - } - for(int i = 0; i < errors.size(); i++) { - if(!errors.get(i).equals(other.errors.get(i))) { - return false; - } - } - return true; - } + public static final class AuthorizerNotEmpty extends LogicError { + public AuthorizerNotEmpty() {} - @Override - public String toString() { - return "NoMatchingPolicy{}"; - } + @Override + public int hashCode() { + return super.hashCode(); + } - @Override - public JsonElement toJson() { - JsonObject jo = new JsonObject(); - JsonArray ja = new JsonArray(); - for (FailedCheck t: this.errors) { - ja.add(t.toJson()); - } - jo.add("NoMatchingPolicy", ja); - return jo; - } + @Override + public boolean equals(Object obj) { + return super.equals(obj); } - public static class AuthorizerNotEmpty extends LogicError { + @Override + public String toString() { + return "AuthorizerNotEmpty"; + } + } - public AuthorizerNotEmpty() { + public abstract static class MatchedPolicy { + public abstract JsonElement toJson(); - } + public static final class Allow extends MatchedPolicy { + public final long nb; - @Override - public int hashCode() { - return super.hashCode(); - } + public Allow(long nb) { + this.nb = nb; + } - @Override - public boolean equals(Object obj) { - return super.equals(obj); - } + @Override + public String toString() { + return "Allow(" + this.nb + ")"; + } - @Override - public String toString() { - return "AuthorizerNotEmpty"; - } + public JsonElement toJson() { + JsonObject jo = new JsonObject(); + jo.addProperty("Allow", this.nb); + return jo; + } } - public static abstract class MatchedPolicy { - public abstract JsonElement toJson(); + public static final class Deny extends MatchedPolicy { + public final long nb; - public static class Allow extends MatchedPolicy { - final public long nb; - public Allow(long nb){ - this.nb = nb; - } + public Deny(long nb) { + this.nb = nb; + } - @Override - public String toString(){ - return "Allow("+this.nb+")"; - } + @Override + public String toString() { + return "Deny(" + this.nb + ")"; + } - public JsonElement toJson(){ - JsonObject jo = new JsonObject(); - jo.addProperty("Allow", this.nb); - return jo; - } - } - - public static class Deny extends MatchedPolicy { - public final long nb; - public Deny(long nb){ - this.nb = nb; - } - - @Override - public String toString(){ - return "Deny("+this.nb+")"; - } - - public JsonElement toJson(){ - JsonObject jo = new JsonObject(); - jo.addProperty("Deny", this.nb); - return jo; - } - } + public JsonElement toJson() { + JsonObject jo = new JsonObject(); + jo.addProperty("Deny", this.nb); + return jo; + } } -} \ No newline at end of file + } +} diff --git a/src/main/java/org/biscuitsec/biscuit/error/package-info.java b/src/main/java/org/biscuitsec/biscuit/error/package-info.java index 29a9773d..0ffe2dbd 100644 --- a/src/main/java/org/biscuitsec/biscuit/error/package-info.java +++ b/src/main/java/org/biscuitsec/biscuit/error/package-info.java @@ -1,4 +1,2 @@ -/** - * Error description classes - */ -package org.biscuitsec.biscuit.error; \ No newline at end of file +/** Error description classes */ +package org.biscuitsec.biscuit.error; diff --git a/src/main/java/org/biscuitsec/biscuit/token/Authorizer.java b/src/main/java/org/biscuitsec/biscuit/token/Authorizer.java index 7b6c7f6e..d9892702 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/Authorizer.java +++ b/src/main/java/org/biscuitsec/biscuit/token/Authorizer.java @@ -1,678 +1,719 @@ package org.biscuitsec.biscuit.token; -import org.biscuitsec.biscuit.crypto.PublicKey; -import org.biscuitsec.biscuit.datalog.*; -import org.biscuitsec.biscuit.error.Error; -import org.biscuitsec.biscuit.error.FailedCheck; -import org.biscuitsec.biscuit.error.LogicError; -import org.biscuitsec.biscuit.token.builder.*; +import static io.vavr.API.Left; +import static io.vavr.API.Right; + import io.vavr.Tuple2; import io.vavr.control.Either; import io.vavr.control.Option; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Date; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.biscuitsec.biscuit.crypto.PublicKey; +import org.biscuitsec.biscuit.datalog.FactSet; +import org.biscuitsec.biscuit.datalog.Origin; +import org.biscuitsec.biscuit.datalog.RuleSet; +import org.biscuitsec.biscuit.datalog.RunLimits; import org.biscuitsec.biscuit.datalog.Scope; +import org.biscuitsec.biscuit.datalog.SymbolTable; +import org.biscuitsec.biscuit.datalog.TrustedOrigins; +import org.biscuitsec.biscuit.datalog.World; +import org.biscuitsec.biscuit.error.Error; +import org.biscuitsec.biscuit.error.FailedCheck; +import org.biscuitsec.biscuit.error.LogicError; import org.biscuitsec.biscuit.token.builder.Check; +import org.biscuitsec.biscuit.token.builder.Expression; import org.biscuitsec.biscuit.token.builder.Term; +import org.biscuitsec.biscuit.token.builder.Utils; import org.biscuitsec.biscuit.token.builder.parser.Parser; -import java.time.Instant; -import java.util.*; -import java.util.stream.Collectors; +/** Token verification class */ +public final class Authorizer { + private Biscuit token; + private final List checks; + private final List policies; + private final List scopes; + private final HashMap> publicKeyToBlockId; + private final World world; + private final SymbolTable symbolTable; + + private Authorizer(Biscuit token, World w) throws Error.FailedLogic { + this.token = token; + this.world = w; + this.symbolTable = new SymbolTable(this.token.symbolTable); + this.checks = new ArrayList<>(); + this.policies = new ArrayList<>(); + this.scopes = new ArrayList<>(); + this.publicKeyToBlockId = new HashMap<>(); + updateOnToken(); + } + + /** + * Creates an empty authorizer + * + *

used to apply policies when unauthenticated (no token) and to preload an authorizer that is + * cloned for each new request + */ + public Authorizer() { + this.world = new World(); + this.symbolTable = Biscuit.defaultSymbolTable(); + this.checks = new ArrayList<>(); + this.policies = new ArrayList<>(); + this.scopes = new ArrayList<>(); + this.publicKeyToBlockId = new HashMap<>(); + } + + private Authorizer( + Biscuit token, + List checks, + List policies, + World world, + SymbolTable symbolTable) { + this.token = token; + this.checks = checks; + this.policies = policies; + this.world = world; + this.symbolTable = symbolTable; + this.scopes = new ArrayList<>(); + this.publicKeyToBlockId = new HashMap<>(); + } + + /** + * Creates a authorizer for a token + * + *

also checks that the token is valid for this root public key + * + * @param token + * @return Authorizer + */ + public static Authorizer make(Biscuit token) throws Error.FailedLogic { + return new Authorizer(token, new World()); + } + + public Authorizer clone() { + return new Authorizer( + this.token, + new ArrayList<>(this.checks), + new ArrayList<>(this.policies), + new World(this.world), + new SymbolTable(this.symbolTable)); + } + + public void updateOnToken() throws Error.FailedLogic { + if (token != null) { + for (long i = 0; i < token.blocks.size(); i++) { + Block block = token.blocks.get((int) i); + + if (block.getExternalKey().isDefined()) { + PublicKey pk = block.getExternalKey().get(); + long newKeyId = this.symbolTable.insert(pk); + if (!this.publicKeyToBlockId.containsKey(newKeyId)) { + List l = new ArrayList<>(); + l.add(i + 1); + this.publicKeyToBlockId.put(newKeyId, l); + } else { + this.publicKeyToBlockId.get(newKeyId).add(i + 1); + } + } + } + + TrustedOrigins authorityTrustedOrigins = + TrustedOrigins.fromScopes( + token.authority.getScopes(), + TrustedOrigins.defaultOrigins(), + 0, + this.publicKeyToBlockId); + + for (org.biscuitsec.biscuit.datalog.Fact fact : token.authority.getFacts()) { + org.biscuitsec.biscuit.datalog.Fact convertedFact = + org.biscuitsec.biscuit.token.builder.Fact.convertFrom(fact, token.symbolTable) + .convert(this.symbolTable); + world.addFact(new Origin(0), convertedFact); + } + for (org.biscuitsec.biscuit.datalog.Rule rule : token.authority.getRules()) { + org.biscuitsec.biscuit.token.builder.Rule locRule = + org.biscuitsec.biscuit.token.builder.Rule.convertFrom(rule, token.symbolTable); + org.biscuitsec.biscuit.datalog.Rule convertedRule = locRule.convert(this.symbolTable); + + Either res = locRule.validateVariables(); + if (res.isLeft()) { + throw new Error.FailedLogic( + new LogicError.InvalidBlockRule(0, token.symbolTable.formatRule(convertedRule))); + } + TrustedOrigins ruleTrustedOrigins = + TrustedOrigins.fromScopes( + convertedRule.scopes(), authorityTrustedOrigins, 0, this.publicKeyToBlockId); + world.addRule((long) 0, ruleTrustedOrigins, convertedRule); + } + + for (long i = 0; i < token.blocks.size(); i++) { + Block block = token.blocks.get((int) i); + TrustedOrigins blockTrustedOrigins = + TrustedOrigins.fromScopes( + block.getScopes(), TrustedOrigins.defaultOrigins(), i + 1, this.publicKeyToBlockId); + + SymbolTable blockSymbolTable = token.symbolTable; + + if (block.getExternalKey().isDefined()) { + blockSymbolTable = new SymbolTable(block.getSymbolTable(), block.getPublicKeys()); + } -import static io.vavr.API.Left; -import static io.vavr.API.Right; + for (org.biscuitsec.biscuit.datalog.Fact fact : block.getFacts()) { + org.biscuitsec.biscuit.datalog.Fact convertedFact = + org.biscuitsec.biscuit.token.builder.Fact.convertFrom(fact, blockSymbolTable) + .convert(this.symbolTable); + world.addFact(new Origin(i + 1), convertedFact); + } -/** - * Token verification class - */ -public class Authorizer { - Biscuit token; - List checks; - List policies; - List scopes; - HashMap> publicKeyToBlockId; - World world; - SymbolTable symbols; - - private Authorizer(Biscuit token, World w) throws Error.FailedLogic { - this.token = token; - this.world = w; - this.symbols = new SymbolTable(this.token.symbols); - this.checks = new ArrayList<>(); - this.policies = new ArrayList<>(); - this.scopes = new ArrayList<>(); - this.publicKeyToBlockId = new HashMap<>(); - update_on_token(); + for (org.biscuitsec.biscuit.datalog.Rule rule : block.getRules()) { + org.biscuitsec.biscuit.token.builder.Rule syRole = + org.biscuitsec.biscuit.token.builder.Rule.convertFrom(rule, blockSymbolTable); + org.biscuitsec.biscuit.datalog.Rule convertedRule = syRole.convert(this.symbolTable); + + Either res = + syRole.validateVariables(); + if (res.isLeft()) { + throw new Error.FailedLogic( + new LogicError.InvalidBlockRule(0, this.symbolTable.formatRule(convertedRule))); + } + TrustedOrigins ruleTrustedOrigins = + TrustedOrigins.fromScopes( + convertedRule.scopes(), blockTrustedOrigins, i + 1, this.publicKeyToBlockId); + world.addRule((long) i + 1, ruleTrustedOrigins, convertedRule); + } + } } + } - /** - * Creates an empty authorizer - *

- * used to apply policies when unauthenticated (no token) - * and to preload an authorizer that is cloned for each new request - */ - public Authorizer() { - this.world = new World(); - this.symbols = Biscuit.default_symbol_table(); - this.checks = new ArrayList<>(); - this.policies = new ArrayList<>(); - this.scopes = new ArrayList<>(); - this.publicKeyToBlockId = new HashMap<>(); + public Authorizer addToken(Biscuit token) throws Error.FailedLogic { + if (this.token != null) { + throw new Error.FailedLogic(new LogicError.AuthorizerNotEmpty()); } - private Authorizer(Biscuit token, List checks, List policies, - World world, SymbolTable symbols) { - this.token = token; - this.checks = checks; - this.policies = policies; - this.world = world; - this.symbols = symbols; - this.scopes = new ArrayList<>(); - this.publicKeyToBlockId = new HashMap<>(); - } + this.token = token; + updateOnToken(); + return this; + } - /** - * Creates a authorizer for a token - *

- * also checks that the token is valid for this root public key - * - * @param token - * @return Authorizer - */ - static public Authorizer make(Biscuit token) throws Error.FailedLogic { - return new Authorizer(token, new World()); - } + public Authorizer addFact(org.biscuitsec.biscuit.token.builder.Fact fact) { + world.addFact(Origin.authorizer(), fact.convert(symbolTable)); + return this; + } - public Authorizer clone() { - return new Authorizer(this.token, new ArrayList<>(this.checks), new ArrayList<>(this.policies), - new World(this.world), new SymbolTable(this.symbols)); + public Authorizer addFact(String s) throws Error.Parser { + Either< + org.biscuitsec.biscuit.token.builder.parser.Error, + Tuple2> + res = Parser.fact(s); + + if (res.isLeft()) { + throw new Error.Parser(res.getLeft()); } - public void update_on_token() throws Error.FailedLogic { - if (token != null) { - for(long i =0; i < token.blocks.size(); i++) { - Block block = token.blocks.get((int) i); - - if (block.externalKey.isDefined()) { - PublicKey pk = block.externalKey.get(); - long newKeyId = this.symbols.insert(pk); - if (!this.publicKeyToBlockId.containsKey(newKeyId)) { - List l = new ArrayList<>(); - l.add(i + 1); - this.publicKeyToBlockId.put(newKeyId, l); - } else { - this.publicKeyToBlockId.get(newKeyId).add(i + 1); - } - } - } + Tuple2 t = res.get(); - TrustedOrigins authorityTrustedOrigins = TrustedOrigins.fromScopes( - token.authority.scopes, - TrustedOrigins.defaultOrigins(), - 0, - this.publicKeyToBlockId - ); + return this.addFact(t._2); + } - for (org.biscuitsec.biscuit.datalog.Fact fact : token.authority.facts) { - org.biscuitsec.biscuit.datalog.Fact converted_fact = org.biscuitsec.biscuit.token.builder.Fact.convert_from(fact, token.symbols).convert(this.symbols); - world.add_fact(new Origin(0), converted_fact); - } - for (org.biscuitsec.biscuit.datalog.Rule rule : token.authority.rules) { - org.biscuitsec.biscuit.token.builder.Rule _rule = org.biscuitsec.biscuit.token.builder.Rule.convert_from(rule, token.symbols); - org.biscuitsec.biscuit.datalog.Rule converted_rule = _rule.convert(this.symbols); - - Either res = _rule.validate_variables(); - if(res.isLeft()){ - throw new Error.FailedLogic(new LogicError.InvalidBlockRule(0, token.symbols.print_rule(converted_rule))); - } - TrustedOrigins ruleTrustedOrigins = TrustedOrigins.fromScopes( - converted_rule.scopes(), - authorityTrustedOrigins, - 0, - this.publicKeyToBlockId - ); - world.add_rule((long) 0, ruleTrustedOrigins, converted_rule); - } + public Authorizer addRule(org.biscuitsec.biscuit.token.builder.Rule rule) { + org.biscuitsec.biscuit.datalog.Rule r = rule.convert(symbolTable); + TrustedOrigins ruleTrustedOrigins = + TrustedOrigins.fromScopes( + r.scopes(), this.authorizerTrustedOrigins(), Long.MAX_VALUE, this.publicKeyToBlockId); + world.addRule(Long.MAX_VALUE, ruleTrustedOrigins, r); + return this; + } - for(long i =0; i < token.blocks.size(); i++) { - Block block = token.blocks.get((int)i); - TrustedOrigins blockTrustedOrigins = TrustedOrigins.fromScopes( - block.scopes, - TrustedOrigins.defaultOrigins(), - i + 1, - this.publicKeyToBlockId - ); - - SymbolTable blockSymbols = token.symbols; - - if(block.externalKey.isDefined()) { - blockSymbols = new SymbolTable(block.symbols.symbols, block.publicKeys()); - } - - for (org.biscuitsec.biscuit.datalog.Fact fact : block.facts) { - org.biscuitsec.biscuit.datalog.Fact converted_fact = org.biscuitsec.biscuit.token.builder.Fact.convert_from(fact, blockSymbols).convert(this.symbols); - world.add_fact(new Origin(i + 1), converted_fact); - } - - for (org.biscuitsec.biscuit.datalog.Rule rule : block.rules) { - org.biscuitsec.biscuit.token.builder.Rule _rule = org.biscuitsec.biscuit.token.builder.Rule.convert_from(rule, blockSymbols); - org.biscuitsec.biscuit.datalog.Rule converted_rule = _rule.convert(this.symbols); - - Either res = _rule.validate_variables(); - if (res.isLeft()) { - throw new Error.FailedLogic(new LogicError.InvalidBlockRule(0, this.symbols.print_rule(converted_rule))); - } - TrustedOrigins ruleTrustedOrigins = TrustedOrigins.fromScopes( - converted_rule.scopes(), - blockTrustedOrigins, - i + 1, - this.publicKeyToBlockId - ); - world.add_rule((long) i + 1, ruleTrustedOrigins, converted_rule); - } - } - } - } + public Authorizer addRule(String s) throws Error.Parser { + Either< + org.biscuitsec.biscuit.token.builder.parser.Error, + Tuple2> + res = Parser.rule(s); - public Authorizer add_token(Biscuit token) throws Error.FailedLogic { - if (this.token != null) { - throw new Error.FailedLogic(new LogicError.AuthorizerNotEmpty()); - } - - this.token = token; - update_on_token(); - return this; + if (res.isLeft()) { + throw new Error.Parser(res.getLeft()); } - public Authorizer add_fact(org.biscuitsec.biscuit.token.builder.Fact fact) { - world.add_fact(Origin.authorizer(), fact.convert(symbols)); - return this; - } + Tuple2 t = res.get(); - public Authorizer add_fact(String s) throws Error.Parser { - Either> res = - Parser.fact(s); + return addRule(t._2); + } - if (res.isLeft()) { - throw new Error.Parser(res.getLeft()); - } + public TrustedOrigins authorizerTrustedOrigins() { + return TrustedOrigins.fromScopes( + this.scopes, TrustedOrigins.defaultOrigins(), Long.MAX_VALUE, this.publicKeyToBlockId); + } - Tuple2 t = res.get(); + public Authorizer addCheck(org.biscuitsec.biscuit.token.builder.Check check) { + this.checks.add(check); + return this; + } - return this.add_fact(t._2); - } + public Authorizer addCheck(String s) throws Error.Parser { + Either< + org.biscuitsec.biscuit.token.builder.parser.Error, + Tuple2> + res = Parser.check(s); - public Authorizer add_rule(org.biscuitsec.biscuit.token.builder.Rule rule) { - org.biscuitsec.biscuit.datalog.Rule r = rule.convert(symbols); - TrustedOrigins ruleTrustedOrigins = TrustedOrigins.fromScopes( - r.scopes(), - this.authorizerTrustedOrigins(), - Long.MAX_VALUE, - this.publicKeyToBlockId - ); - world.add_rule(Long.MAX_VALUE, ruleTrustedOrigins, r); - return this; + if (res.isLeft()) { + throw new Error.Parser(res.getLeft()); } - public TrustedOrigins authorizerTrustedOrigins() { - return TrustedOrigins.fromScopes( - this.scopes, - TrustedOrigins.defaultOrigins(), - Long.MAX_VALUE, - this.publicKeyToBlockId - ); - } + Tuple2 t = res.get(); - public Authorizer add_rule(String s) throws Error.Parser { - Either> res = - Parser.rule(s); + return addCheck(t._2); + } - if (res.isLeft()) { - throw new Error.Parser(res.getLeft()); - } + public Authorizer setTime() throws Error.Language { + world.addFact( + Origin.authorizer(), Utils.fact("time", List.of(Utils.date(new Date()))).convert(symbolTable)); + return this; + } - Tuple2 t = res.get(); + public List getRevocationIds() throws Error { + ArrayList ids = new ArrayList<>(); - return add_rule(t._2); - } + final org.biscuitsec.biscuit.token.builder.Rule getRevocationIds = + Utils.rule( + "revocation_id", + List.of(Utils.var("id")), + List.of(Utils.pred("revocation_id", List.of(Utils.var("id"))))); - public Authorizer add_check(org.biscuitsec.biscuit.token.builder.Check check) { - this.checks.add(check); - return this; - } + this.query(getRevocationIds).stream() + .forEach( + fact -> { + fact.terms().stream() + .forEach( + id -> { + if (id instanceof Term.Str) { + ids.add(((Term.Str) id).getValue()); + } + }); + }); - public Authorizer add_check(String s) throws Error.Parser { - Either> res = - Parser.check(s); + return ids; + } - if (res.isLeft()) { - throw new Error.Parser(res.getLeft()); - } + public Authorizer allow() { + ArrayList q = new ArrayList<>(); - Tuple2 t = res.get(); + q.add( + Utils.constrainedRule( + "allow", + new ArrayList<>(), + new ArrayList<>(), + List.of(new Expression.Value(new Term.Bool(true))))); - return add_check(t._2); - } + this.policies.add(new Policy(q, Policy.Kind.ALLOW)); + return this; + } - public Authorizer set_time() throws Error.Language { - world.add_fact(Origin.authorizer(), Utils.fact("time", List.of(Utils.date(new Date()))).convert(symbols)); - return this; - } + public Authorizer deny() { + ArrayList q = new ArrayList<>(); - public List get_revocation_ids() throws Error { - ArrayList ids = new ArrayList<>(); + q.add( + Utils.constrainedRule( + "deny", + new ArrayList<>(), + new ArrayList<>(), + List.of(new Expression.Value(new Term.Bool(true))))); - final org.biscuitsec.biscuit.token.builder.Rule getRevocationIds = Utils.rule( - "revocation_id", - List.of(Utils.var("id")), - List.of(Utils.pred("revocation_id", List.of(Utils.var("id")))) - ); + this.policies.add(new Policy(q, Policy.Kind.DENY)); + return this; + } - this.query(getRevocationIds).stream().forEach(fact -> { - fact.terms().stream().forEach(id -> { - if (id instanceof Term.Str) { - ids.add(((Term.Str) id).getValue()); - } - }); - }); + public Authorizer addPolicy(String s) throws Error.Parser { + Either> res = + Parser.policy(s); - return ids; + if (res.isLeft()) { + throw new Error.Parser(res.getLeft()); } - public Authorizer allow() { - ArrayList q = new ArrayList<>(); + Tuple2 t = res.get(); - q.add(Utils.constrained_rule( - "allow", - new ArrayList<>(), - new ArrayList<>(), - List.of(new Expression.Value(new Term.Bool(true))) - )); + this.policies.add(t._2); + return this; + } - this.policies.add(new Policy(q, Policy.Kind.Allow)); - return this; - } + public Authorizer addPolicy(Policy p) { + this.policies.add(p); + return this; + } + + public Authorizer addScope(Scope s) { + this.scopes.add(s); + return this; + } - public Authorizer deny() { - ArrayList q = new ArrayList<>(); + public Set query( + org.biscuitsec.biscuit.token.builder.Rule query) throws Error { + return this.query(query, new RunLimits()); + } - q.add(Utils.constrained_rule( - "deny", - new ArrayList<>(), - new ArrayList<>(), - List.of(new Expression.Value(new Term.Bool(true))) - )); + public Set query(String s) throws Error { + Either< + org.biscuitsec.biscuit.token.builder.parser.Error, + Tuple2> + res = Parser.rule(s); - this.policies.add(new Policy(q, Policy.Kind.Deny)); - return this; + if (res.isLeft()) { + throw new Error.Parser(res.getLeft()); } - public Authorizer add_policy(String s) throws Error.Parser { - Either> res = - Parser.policy(s); + Tuple2 t = res.get(); - if (res.isLeft()) { - throw new Error.Parser(res.getLeft()); - } + return query(t._2); + } - Tuple2 t = res.get(); + public Set query( + org.biscuitsec.biscuit.token.builder.Rule query, RunLimits limits) throws Error { + world.run(limits, symbolTable); - this.policies.add(t._2); - return this; - } + org.biscuitsec.biscuit.datalog.Rule rule = query.convert(symbolTable); + TrustedOrigins ruleTrustedorigins = + TrustedOrigins.fromScopes( + rule.scopes(), + TrustedOrigins.defaultOrigins(), + Long.MAX_VALUE, + this.publicKeyToBlockId); - public Authorizer add_policy(Policy p) { - this.policies.add(p); - return this; - } + FactSet facts = world.queryRule(rule, Long.MAX_VALUE, ruleTrustedorigins, symbolTable); + Set s = new HashSet<>(); - public Authorizer add_scope(Scope s) { - this.scopes.add(s); - return this; + for (Iterator it = facts.stream().iterator(); + it.hasNext(); ) { + org.biscuitsec.biscuit.datalog.Fact f = it.next(); + s.add(org.biscuitsec.biscuit.token.builder.Fact.convertFrom(f, symbolTable)); } - public Set query(org.biscuitsec.biscuit.token.builder.Rule query) throws Error { - return this.query(query, new RunLimits()); - } + return s; + } - public Set query(String s) throws Error { - Either> res = - Parser.rule(s); + public Set query(String s, RunLimits limits) + throws Error { + Either< + org.biscuitsec.biscuit.token.builder.parser.Error, + Tuple2> + res = Parser.rule(s); - if (res.isLeft()) { - throw new Error.Parser(res.getLeft()); - } + if (res.isLeft()) { + throw new Error.Parser(res.getLeft()); + } - Tuple2 t = res.get(); + Tuple2 t = res.get(); - return query(t._2); - } + return query(t._2, limits); + } - public Set query(org.biscuitsec.biscuit.token.builder.Rule query, RunLimits limits) throws Error { - world.run(limits, symbols); + public Long authorize() throws Error { + return this.authorize(new RunLimits()); + } - org.biscuitsec.biscuit.datalog.Rule rule = query.convert(symbols); - TrustedOrigins ruleTrustedorigins = TrustedOrigins.fromScopes( - rule.scopes(), - TrustedOrigins.defaultOrigins(), - Long.MAX_VALUE, - this.publicKeyToBlockId - ); + public Long authorize(RunLimits limits) throws Error { + Instant timeLimit = Instant.now().plus(limits.getMaxTime()); + List errors = new LinkedList<>(); - FactSet facts = world.query_rule(rule, Long.MAX_VALUE, - ruleTrustedorigins, symbols); - Set s = new HashSet<>(); + TrustedOrigins authorizerTrustedOrigins = this.authorizerTrustedOrigins(); - for (Iterator it = facts.stream().iterator(); it.hasNext(); ) { - org.biscuitsec.biscuit.datalog.Fact f = it.next(); - s.add(org.biscuitsec.biscuit.token.builder.Fact.convert_from(f, symbols)); - } + world.run(limits, symbolTable); - return s; - } + for (int i = 0; i < this.checks.size(); i++) { + org.biscuitsec.biscuit.datalog.Check c = this.checks.get(i).convert(symbolTable); + boolean successful = false; - public Set query(String s, RunLimits limits) throws Error { - Either> res = - Parser.rule(s); + for (int j = 0; j < c.queries().size(); j++) { + boolean res = false; + org.biscuitsec.biscuit.datalog.Rule query = c.queries().get(j); + TrustedOrigins ruleTrustedOrigins = + TrustedOrigins.fromScopes( + query.scopes(), authorizerTrustedOrigins, Long.MAX_VALUE, this.publicKeyToBlockId); + switch (c.kind()) { + case ONE: + res = world.queryMatch(query, Long.MAX_VALUE, ruleTrustedOrigins, symbolTable); + break; + case ALL: + res = world.queryMatchAll(query, ruleTrustedOrigins, symbolTable); + break; + default: + throw new RuntimeException("unmapped kind"); + } - if (res.isLeft()) { - throw new Error.Parser(res.getLeft()); + if (Instant.now().compareTo(timeLimit) >= 0) { + throw new Error.Timeout(); } - Tuple2 t = res.get(); + if (res) { + successful = true; + break; + } + } + + if (!successful) { + errors.add(new FailedCheck.FailedAuthorizer(i, symbolTable.formatCheck(c))); + } + } + + if (token != null) { + TrustedOrigins authorityTrustedOrigins = + TrustedOrigins.fromScopes( + token.authority.getScopes(), + TrustedOrigins.defaultOrigins(), + 0, + this.publicKeyToBlockId); + + for (int j = 0; j < token.authority.getChecks().size(); j++) { + boolean successful = false; + + org.biscuitsec.biscuit.token.builder.Check c = + org.biscuitsec.biscuit.token.builder.Check.convertFrom( + token.authority.getChecks().get(j), token.symbolTable); + org.biscuitsec.biscuit.datalog.Check check = c.convert(symbolTable); + + for (int k = 0; k < check.queries().size(); k++) { + boolean res = false; + org.biscuitsec.biscuit.datalog.Rule query = check.queries().get(k); + TrustedOrigins ruleTrustedOrigins = + TrustedOrigins.fromScopes( + query.scopes(), authorityTrustedOrigins, 0, this.publicKeyToBlockId); + switch (check.kind()) { + case ONE: + res = world.queryMatch(query, (long) 0, ruleTrustedOrigins, symbolTable); + break; + case ALL: + res = world.queryMatchAll(query, ruleTrustedOrigins, symbolTable); + break; + default: + throw new RuntimeException("unmapped kind"); + } + + if (Instant.now().compareTo(timeLimit) >= 0) { + throw new Error.Timeout(); + } + + if (res) { + successful = true; + break; + } + } - return query(t._2, limits); + if (!successful) { + errors.add(new FailedCheck.FailedBlock(0, j, symbolTable.formatCheck(check))); + } + } } - public Long authorize() throws Error { - return this.authorize(new RunLimits()); - } + Option> policyResult = Option.none(); + policies_test: + for (int i = 0; i < this.policies.size(); i++) { + Policy policy = this.policies.get(i); - public Long authorize(RunLimits limits) throws Error { - Instant timeLimit = Instant.now().plus(limits.maxTime); - List errors = new LinkedList<>(); - Option> policy_result = Option.none(); - - TrustedOrigins authorizerTrustedOrigins = this.authorizerTrustedOrigins(); - - world.run(limits, symbols); - - for (int i = 0; i < this.checks.size(); i++) { - org.biscuitsec.biscuit.datalog.Check c = this.checks.get(i).convert(symbols); - boolean successful = false; - - for (int j = 0; j < c.queries().size(); j++) { - boolean res = false; - org.biscuitsec.biscuit.datalog.Rule query = c.queries().get(j); - TrustedOrigins ruleTrustedOrigins = TrustedOrigins.fromScopes( - query.scopes(), - authorizerTrustedOrigins, - Long.MAX_VALUE, - this.publicKeyToBlockId - ); - switch (c.kind()) { - case One: - res = world.query_match(query, Long.MAX_VALUE, ruleTrustedOrigins, symbols); - break; - case All: - res = world.query_match_all(query, ruleTrustedOrigins, symbols); - break; - } - - if (Instant.now().compareTo(timeLimit) >= 0) { - throw new Error.Timeout(); - } - - if (res) { - successful = true; - break; - } - } + for (int j = 0; j < policy.queries().size(); j++) { + org.biscuitsec.biscuit.datalog.Rule query = policy.queries().get(j).convert(symbolTable); + TrustedOrigins policyTrustedOrigins = + TrustedOrigins.fromScopes( + query.scopes(), authorizerTrustedOrigins, Long.MAX_VALUE, this.publicKeyToBlockId); + boolean res = world.queryMatch(query, Long.MAX_VALUE, policyTrustedOrigins, symbolTable); - if (!successful) { - errors.add(new FailedCheck.FailedAuthorizer(i, symbols.print_check(c))); - } + if (Instant.now().compareTo(timeLimit) >= 0) { + throw new Error.Timeout(); } - if (token != null) { - TrustedOrigins authorityTrustedOrigins = TrustedOrigins.fromScopes( - token.authority.scopes, - TrustedOrigins.defaultOrigins(), - 0, - this.publicKeyToBlockId - ); - - for (int j = 0; j < token.authority.checks.size(); j++) { - boolean successful = false; - - org.biscuitsec.biscuit.token.builder.Check c = org.biscuitsec.biscuit.token.builder.Check.convert_from(token.authority.checks.get(j), token.symbols); - org.biscuitsec.biscuit.datalog.Check check = c.convert(symbols); - - for (int k = 0; k < check.queries().size(); k++) { - boolean res = false; - org.biscuitsec.biscuit.datalog.Rule query = check.queries().get(k); - TrustedOrigins ruleTrustedOrigins = TrustedOrigins.fromScopes( - query.scopes(), - authorityTrustedOrigins, - 0, - this.publicKeyToBlockId - ); - switch (check.kind()) { - case One: - res = world.query_match(query, (long)0, ruleTrustedOrigins, symbols); - break; - case All: - res = world.query_match_all(query, ruleTrustedOrigins, symbols); - break; - } - - if (Instant.now().compareTo(timeLimit) >= 0) { - throw new Error.Timeout(); - } - - if (res) { - successful = true; - break; - } - } - - if (!successful) { - errors.add(new FailedCheck.FailedBlock(0, j, symbols.print_check(check))); - } - } + if (res) { + if (this.policies.get(i).kind() == Policy.Kind.ALLOW) { + policyResult = Option.some(Right(i)); + } else { + policyResult = Option.some(Left(i)); + } + break policies_test; } - - policies_test: - for (int i = 0; i < this.policies.size(); i++) { - Policy policy = this.policies.get(i); - - for (int j = 0; j < policy.queries.size(); j++) { - org.biscuitsec.biscuit.datalog.Rule query = policy.queries.get(j).convert(symbols); - TrustedOrigins policyTrustedOrigins = TrustedOrigins.fromScopes( - query.scopes(), - authorizerTrustedOrigins, - Long.MAX_VALUE, - this.publicKeyToBlockId - ); - boolean res = world.query_match(query, Long.MAX_VALUE, policyTrustedOrigins, symbols); - - if (Instant.now().compareTo(timeLimit) >= 0) { - throw new Error.Timeout(); - } - - if (res) { - if (this.policies.get(i).kind == Policy.Kind.Allow) { - policy_result = Option.some(Right(i)); - } else { - policy_result = Option.some(Left(i)); - } - break policies_test; - } - } + } + } + + if (token != null) { + for (int i = 0; i < token.blocks.size(); i++) { + org.biscuitsec.biscuit.token.Block b = token.blocks.get(i); + TrustedOrigins blockTrustedOrigins = + TrustedOrigins.fromScopes( + b.getScopes(), TrustedOrigins.defaultOrigins(), i + 1, this.publicKeyToBlockId); + SymbolTable blockSymbolTable = token.symbolTable; + if (b.getExternalKey().isDefined()) { + blockSymbolTable = new SymbolTable(b.getSymbolTable(), b.getPublicKeys()); } - if (token != null) { - for (int i = 0; i < token.blocks.size(); i++) { - org.biscuitsec.biscuit.token.Block b = token.blocks.get(i); - TrustedOrigins blockTrustedOrigins = TrustedOrigins.fromScopes( - b.scopes, - TrustedOrigins.defaultOrigins(), - i+1, - this.publicKeyToBlockId - ); - SymbolTable blockSymbols = token.symbols; - if(b.externalKey.isDefined()) { - blockSymbols = new SymbolTable(b.symbols.symbols, b.publicKeys()); - } - - for (int j = 0; j < b.checks.size(); j++) { - boolean successful = false; - - org.biscuitsec.biscuit.token.builder.Check c = org.biscuitsec.biscuit.token.builder.Check.convert_from(b.checks.get(j), blockSymbols); - org.biscuitsec.biscuit.datalog.Check check = c.convert(symbols); - - for (int k = 0; k < check.queries().size(); k++) { - boolean res = false; - org.biscuitsec.biscuit.datalog.Rule query = check.queries().get(k); - TrustedOrigins ruleTrustedOrigins = TrustedOrigins.fromScopes( - query.scopes(), - blockTrustedOrigins, - i+1, - this.publicKeyToBlockId - ); - switch (check.kind()) { - case One: - res = world.query_match(query, (long)i+1, ruleTrustedOrigins, symbols); - break; - case All: - res = world.query_match_all(query, ruleTrustedOrigins, symbols); - break; - } - - if (Instant.now().compareTo(timeLimit) >= 0) { - throw new Error.Timeout(); - } - - if (res) { - successful = true; - break; - } - } + for (int j = 0; j < b.getChecks().size(); j++) { + boolean successful = false; + + org.biscuitsec.biscuit.token.builder.Check c = + org.biscuitsec.biscuit.token.builder.Check.convertFrom( + b.getChecks().get(j), blockSymbolTable); + org.biscuitsec.biscuit.datalog.Check check = c.convert(symbolTable); + + for (int k = 0; k < check.queries().size(); k++) { + boolean res = false; + org.biscuitsec.biscuit.datalog.Rule query = check.queries().get(k); + TrustedOrigins ruleTrustedOrigins = + TrustedOrigins.fromScopes( + query.scopes(), blockTrustedOrigins, i + 1, this.publicKeyToBlockId); + switch (check.kind()) { + case ONE: + res = world.queryMatch(query, (long) i + 1, ruleTrustedOrigins, symbolTable); + break; + case ALL: + res = world.queryMatchAll(query, ruleTrustedOrigins, symbolTable); + break; + default: + throw new RuntimeException("unmapped kind"); + } - if (!successful) { - errors.add(new FailedCheck.FailedBlock(i + 1, j, symbols.print_check(check))); - } - } + if (Instant.now().compareTo(timeLimit) >= 0) { + throw new Error.Timeout(); } - } - if (policy_result.isDefined()) { - Either e = policy_result.get(); - if (e.isRight()) { - if (errors.isEmpty()) { - return e.get().longValue(); - } else { - throw new Error.FailedLogic(new LogicError.Unauthorized(new LogicError.MatchedPolicy.Allow(e.get()), errors)); - } - } else { - throw new Error.FailedLogic(new LogicError.Unauthorized(new LogicError.MatchedPolicy.Deny(e.getLeft()), errors)); + if (res) { + successful = true; + break; } - } else { - throw new Error.FailedLogic(new LogicError.NoMatchingPolicy(errors)); + } + + if (!successful) { + errors.add(new FailedCheck.FailedBlock(i + 1, j, symbolTable.formatCheck(check))); + } } + } } - public String print_world() { - StringBuilder facts = new StringBuilder(); - for(Map.Entry> entry: this.world.facts().facts().entrySet()) { - facts.append("\n\t\t"+entry.getKey()+":"); - for(org.biscuitsec.biscuit.datalog.Fact f: entry.getValue()) { - facts.append("\n\t\t\t"); - facts.append(this.symbols.print_fact(f)); - } + if (policyResult.isDefined()) { + Either e = policyResult.get(); + if (e.isRight()) { + if (errors.isEmpty()) { + return e.get().longValue(); + } else { + throw new Error.FailedLogic( + new LogicError.Unauthorized(new LogicError.MatchedPolicy.Allow(e.get()), errors)); } - final List rules = this.world.rules().stream().map((r) -> this.symbols.print_rule(r)).collect(Collectors.toList()); - - List checks = new ArrayList<>(); - - for (int j = 0; j < this.checks.size(); j++) { - checks.add("Authorizer[" + j + "]: " + this.checks.get(j).toString()); + } else { + throw new Error.FailedLogic( + new LogicError.Unauthorized(new LogicError.MatchedPolicy.Deny(e.getLeft()), errors)); + } + } else { + throw new Error.FailedLogic(new LogicError.NoMatchingPolicy(errors)); + } + } + + public String formatWorld() { + StringBuilder facts = new StringBuilder(); + for (Map.Entry> entry : + this.world.getFacts().facts().entrySet()) { + facts.append("\n\t\t" + entry.getKey() + ":"); + for (org.biscuitsec.biscuit.datalog.Fact f : entry.getValue()) { + facts.append("\n\t\t\t"); + facts.append(this.symbolTable.formatFact(f)); + } + } + final List rules = + this.world.getRules().stream() + .map((r) -> this.symbolTable.formatRule(r)) + .collect(Collectors.toList()); + + List checks = new ArrayList<>(); + + for (int j = 0; j < this.checks.size(); j++) { + checks.add("Authorizer[" + j + "]: " + this.checks.get(j).toString()); + } + + if (this.token != null) { + for (int j = 0; j < this.token.authority.getChecks().size(); j++) { + checks.add( + "Block[0][" + + j + + "]: " + + token.symbolTable.formatCheck(this.token.authority.getChecks().get(j))); + } + + for (int i = 0; i < this.token.blocks.size(); i++) { + Block b = this.token.blocks.get(i); + + SymbolTable blockSymbolTable = token.symbolTable; + if (b.getExternalKey().isDefined()) { + blockSymbolTable = new SymbolTable(b.getSymbolTable(), b.getPublicKeys()); } - if (this.token != null) { - for (int j = 0; j < this.token.authority.checks.size(); j++) { - checks.add("Block[0][" + j + "]: " + token.symbols.print_check(this.token.authority.checks.get(j))); - } + for (int j = 0; j < b.getChecks().size(); j++) { + checks.add( + "Block[" + (i + 1) + "][" + j + "]: " + blockSymbolTable.formatCheck(b.getChecks().get(j))); + } + } + } - for (int i = 0; i < this.token.blocks.size(); i++) { - Block b = this.token.blocks.get(i); + return "World {\n\tfacts: [" + + facts.toString() + // String.join(",\n\t\t", facts) + + + "\n\t],\n\trules: [\n\t\t" + + String.join(",\n\t\t", rules) + + "\n\t],\n\tchecks: [\n\t\t" + + String.join(",\n\t\t", checks) + + "\n\t]\n}"; + } - SymbolTable blockSymbols = token.symbols; - if(b.externalKey.isDefined()) { - blockSymbols = new SymbolTable(b.symbols.symbols, b.publicKeys()); - } + public FactSet getFacts() { + return this.world.getFacts(); + } - for (int j = 0; j < b.checks.size(); j++) { - checks.add("Block[" + (i+1) + "][" + j + "]: " + blockSymbols.print_check(b.checks.get(j))); - } - } - } + public RuleSet getRules() { + return this.world.getRules(); + } - return "World {\n\tfacts: [" + - facts.toString() + - //String.join(",\n\t\t", facts) + - "\n\t],\n\trules: [\n\t\t" + - String.join(",\n\t\t", rules) + - "\n\t],\n\tchecks: [\n\t\t" + - String.join(",\n\t\t", checks) + - "\n\t]\n}"; + public List>> getChecks() { + List>> allChecks = new ArrayList<>(); + if (!this.checks.isEmpty()) { + allChecks.add(new Tuple2<>(Long.MAX_VALUE, this.checks)); } - public FactSet facts() { - return this.world.facts(); + List authorityChecks = new ArrayList<>(); + for (org.biscuitsec.biscuit.datalog.Check check : this.token.authority.getChecks()) { + authorityChecks.add(Check.convertFrom(check, this.token.symbolTable)); } - - public RuleSet rules() { - return this.world.rules(); + if (!authorityChecks.isEmpty()) { + allChecks.add(new Tuple2<>((long) 0, authorityChecks)); } - public List>> checks() { - List>> allChecks = new ArrayList<>(); - if(!this.checks.isEmpty()) { - allChecks.add(new Tuple2<>(Long.MAX_VALUE, this.checks)); - } + long count = 1; + for (Block block : this.token.blocks) { + List blockChecks = new ArrayList<>(); - List authorityChecks = new ArrayList<>(); - for(org.biscuitsec.biscuit.datalog.Check check: this.token.authority.checks) { - authorityChecks.add(Check.convert_from(check, this.token.symbols)); + if (block.getExternalKey().isDefined()) { + SymbolTable blockSymbolTable = new SymbolTable(block.getSymbolTable(), block.getPublicKeys()); + for (org.biscuitsec.biscuit.datalog.Check check : block.getChecks()) { + blockChecks.add(Check.convertFrom(check, blockSymbolTable)); } - if(!authorityChecks.isEmpty()) { - allChecks.add(new Tuple2<>((long) 0, authorityChecks)); + } else { + for (org.biscuitsec.biscuit.datalog.Check check : block.getChecks()) { + blockChecks.add(Check.convertFrom(check, token.symbolTable)); } + } + if (!blockChecks.isEmpty()) { + allChecks.add(new Tuple2<>(count, blockChecks)); + } + count += 1; + } - long count = 1; - for(Block block: this.token.blocks) { - List blockChecks = new ArrayList<>(); - - if(block.externalKey.isDefined()) { - SymbolTable blockSymbols = new SymbolTable(block.symbols.symbols, block.publicKeys()); - for(org.biscuitsec.biscuit.datalog.Check check: block.checks) { - blockChecks.add(Check.convert_from(check, blockSymbols)); - } - } else { - for(org.biscuitsec.biscuit.datalog.Check check: block.checks) { - blockChecks.add(Check.convert_from(check, token.symbols)); - } - } - if(!blockChecks.isEmpty()) { - allChecks.add(new Tuple2<>(count, blockChecks)); - } - count += 1; - } + return allChecks; + } - return allChecks; - } + public List getPolicies() { + return this.policies; + } - public List policies() { - return this.policies; - } + public SymbolTable getSymbolTable() { + return symbolTable; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/Biscuit.java b/src/main/java/org/biscuitsec/biscuit/token/Biscuit.java index 1f71c01e..f78110ac 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/Biscuit.java +++ b/src/main/java/org/biscuitsec/biscuit/token/Biscuit.java @@ -1,394 +1,401 @@ package org.biscuitsec.biscuit.token; import biscuit.format.schema.Schema.PublicKey.Algorithm; +import io.vavr.Tuple2; +import io.vavr.control.Either; +import io.vavr.control.Option; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.security.SignatureException; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; import org.biscuitsec.biscuit.crypto.KeyDelegate; import org.biscuitsec.biscuit.crypto.KeyPair; import org.biscuitsec.biscuit.crypto.PublicKey; import org.biscuitsec.biscuit.datalog.SymbolTable; import org.biscuitsec.biscuit.error.Error; import org.biscuitsec.biscuit.token.format.SerializedBiscuit; -import io.vavr.Tuple2; -import io.vavr.control.Either; -import io.vavr.control.Option; - -import java.security.*; -import java.util.*; - -/** - * Biscuit auth token - */ -public class Biscuit extends UnverifiedBiscuit { - /** - * Creates a token builder - *

- * this function uses the default symbol table - * - * @param root root private key - * @return - */ - public static org.biscuitsec.biscuit.token.builder.Biscuit builder(final org.biscuitsec.biscuit.crypto.Signer root) { - return new org.biscuitsec.biscuit.token.builder.Biscuit(new SecureRandom(), root); - } - - /** - * Creates a token builder - *

- * this function uses the default symbol table - * - * @param rng random number generator - * @param root root private key - * @return - */ - public static org.biscuitsec.biscuit.token.builder.Biscuit builder(final SecureRandom rng, final KeyPair root) { - return new org.biscuitsec.biscuit.token.builder.Biscuit(rng, root); - } - - /** - * Creates a token builder - * - * @param rng random number generator - * @param root root private key - * @return - */ - public static org.biscuitsec.biscuit.token.builder.Biscuit builder(final SecureRandom rng, final org.biscuitsec.biscuit.crypto.Signer root, final Option root_key_id) { - return new org.biscuitsec.biscuit.token.builder.Biscuit(rng, root, root_key_id); - } - - /** - * Creates a token - * - * @param rng random number generator - * @param root root private key - * @param authority authority block - * @return Biscuit - */ - public static Biscuit make(final SecureRandom rng, final org.biscuitsec.biscuit.crypto.Signer root, final Block authority) throws Error.FormatError { - return Biscuit.make(rng, root, Option.none(), authority); - } - - /** - * Creates a token - * - * @param rng random number generator - * @param root root private key - * @param authority authority block - * @return Biscuit - */ - public static Biscuit make(final SecureRandom rng, final org.biscuitsec.biscuit.crypto.Signer root, final Integer root_key_id, final Block authority) throws Error.FormatError { - return Biscuit.make(rng, root, Option.of(root_key_id), authority); - } - - /** - * Creates a token - * - * @param rng random number generator - * @param root root private key - * @param authority authority block - * @return Biscuit - */ - static private Biscuit make(final SecureRandom rng, final org.biscuitsec.biscuit.crypto.Signer root, final Option root_key_id, final Block authority) throws Error.FormatError { - ArrayList blocks = new ArrayList<>(); - - KeyPair next = KeyPair.generate(root.public_key().algorithm, rng); - - for(PublicKey pk: authority.publicKeys) { - authority.symbols.insert(pk); - } - - Either container = SerializedBiscuit.make(root, root_key_id, authority, next); - if (container.isLeft()) { - throw container.getLeft(); - } else { - SerializedBiscuit s = container.get(); - List revocation_ids = s.revocation_identifiers(); - - Option c = Option.some(s); - return new Biscuit(authority, blocks, authority.symbols, s, revocation_ids); - } - } - - Biscuit(Block authority, List blocks, SymbolTable symbols, SerializedBiscuit serializedBiscuit, - List revocation_ids) { - super(authority, blocks, symbols, serializedBiscuit, revocation_ids); - } - - /** - * Deserializes a Biscuit token from a hex string - *

- * This checks the signature, but does not verify that the first key is the root key, - * to allow appending blocks without knowing about the root key. - *

- * The root key check is performed in the verify method - *

- * This method uses the default symbol table - * - * @param data - * @return - */ - @Deprecated - static public Biscuit from_b64(String data, PublicKey root) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - return Biscuit.from_bytes(Base64.getUrlDecoder().decode(data), root); - } - /** - * Deserializes a Biscuit token from a base64 url (RFC4648_URLSAFE) string - *

- * This checks the signature, but does not verify that the first key is the root key, - * to allow appending blocks without knowing about the root key. - *

- * The root key check is performed in the verify method - *

- * This method uses the default symbol table - * - * @param data - * @return Biscuit - */ - static public Biscuit from_b64url(String data, PublicKey root) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - return Biscuit.from_bytes(Base64.getUrlDecoder().decode(data), root); +/** Biscuit auth token */ +public final class Biscuit extends UnverifiedBiscuit { + /** + * Creates a token builder + * + *

this function uses the default symbol table + * + * @param root root private key + * @return + */ + public static org.biscuitsec.biscuit.token.builder.Biscuit builder( + final org.biscuitsec.biscuit.crypto.Signer root) { + return new org.biscuitsec.biscuit.token.builder.Biscuit(new SecureRandom(), root); + } + + /** + * Creates a token builder + * + *

this function uses the default symbol table + * + * @param rng random number generator + * @param root root private key + * @return + */ + public static org.biscuitsec.biscuit.token.builder.Biscuit builder( + final SecureRandom rng, final KeyPair root) { + return new org.biscuitsec.biscuit.token.builder.Biscuit(rng, root); + } + + /** + * Creates a token builder + * + * @param rng random number generator + * @param root root private key + * @return + */ + public static org.biscuitsec.biscuit.token.builder.Biscuit builder( + final SecureRandom rng, + final org.biscuitsec.biscuit.crypto.Signer root, + final Option rootKeyId) { + return new org.biscuitsec.biscuit.token.builder.Biscuit(rng, root, rootKeyId); + } + + /** + * Creates a token + * + * @param rng random number generator + * @param root root private key + * @param authority authority block + * @return Biscuit + */ + public static Biscuit make( + final SecureRandom rng, + final org.biscuitsec.biscuit.crypto.Signer root, + final Block authority) + throws Error.FormatError { + return Biscuit.make(rng, root, Option.none(), authority); + } + + /** + * Creates a token + * + * @param rng random number generator + * @param root root private key + * @param authority authority block + * @return Biscuit + */ + public static Biscuit make( + final SecureRandom rng, + final org.biscuitsec.biscuit.crypto.Signer root, + final Integer rootKeyId, + final Block authority) + throws Error.FormatError { + return Biscuit.make(rng, root, Option.of(rootKeyId), authority); + } + + /** + * Creates a token + * + * @param rng random number generator + * @param root root private key + * @param authority authority block + * @return Biscuit + */ + private static Biscuit make( + final SecureRandom rng, + final org.biscuitsec.biscuit.crypto.Signer root, + final Option rootKeyId, + final Block authority) + throws Error.FormatError { + ArrayList blocks = new ArrayList<>(); + + KeyPair next = KeyPair.generate(root.getPublicKey().getAlgorithm(), rng); + + for (PublicKey pk : authority.getPublicKeys()) { + authority.getSymbolTable().insert(pk); } - /** - * Deserializes a Biscuit token from a base64 url (RFC4648_URLSAFE) string - *

- * This checks the signature, but does not verify that the first key is the root key, - * to allow appending blocks without knowing about the root key. - *

- * The root key check is performed in the verify method - *

- * This method uses the default symbol table - * - * @param data - * @return Biscuit - */ - static public Biscuit from_b64url(String data, KeyDelegate delegate) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - return Biscuit.from_bytes(Base64.getUrlDecoder().decode(data), delegate); - } + Either container = + SerializedBiscuit.make(root, rootKeyId, authority, next); + if (container.isLeft()) { + throw container.getLeft(); + } else { + SerializedBiscuit s = container.get(); + List revocationIds = s.revocationIdentifiers(); - /** - * Deserializes a Biscuit token from a byte array - *

- * This checks the signature, but does not verify that the first key is the root key, - * to allow appending blocks without knowing about the root key. - *

- * The root key check is performed in the verify method - *

- * This method uses the default symbol table - * - * @param data - * @return - */ - static public Biscuit from_bytes(byte[] data, PublicKey root) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - return from_bytes_with_symbols(data, root, default_symbol_table()); + Option c = Option.some(s); + return new Biscuit(authority, blocks, authority.getSymbolTable(), s, revocationIds); } - - /** - * Deserializes a Biscuit token from a byte array - *

- * This checks the signature, but does not verify that the first key is the root key, - * to allow appending blocks without knowing about the root key. - *

- * The root key check is performed in the verify method - *

- * This method uses the default symbol table - * - * @param data - * @return - */ - static public Biscuit from_bytes(byte[] data, KeyDelegate delegate) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - return from_bytes_with_symbols(data, delegate, default_symbol_table()); + } + + Biscuit( + Block authority, + List blocks, + SymbolTable symbolTable, + SerializedBiscuit serializedBiscuit, + List revocationIds) { + super(authority, blocks, symbolTable, serializedBiscuit, revocationIds); + } + + /** + * Deserializes a Biscuit token from a base64 url (RFC4648_URLSAFE) string + * + *

This checks the signature, but does not verify that the first key is the root key, to allow + * appending blocks without knowing about the root key. + * + *

The root key check is performed in the verify method + * + *

This method uses the default symbol table + * + * @param data + * @return Biscuit + */ + public static Biscuit fromBase64Url(String data, PublicKey root) + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + return Biscuit.fromBytes(Base64.getUrlDecoder().decode(data), root); + } + + /** + * Deserializes a Biscuit token from a base64 url (RFC4648_URLSAFE) string + * + *

This checks the signature, but does not verify that the first key is the root key, to allow + * appending blocks without knowing about the root key. + * + *

The root key check is performed in the verify method + * + *

This method uses the default symbol table + * + * @param data + * @return Biscuit + */ + public static Biscuit fromBase64Url(String data, KeyDelegate delegate) + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + return Biscuit.fromBytes(Base64.getUrlDecoder().decode(data), delegate); + } + + /** + * Deserializes a Biscuit token from a byte array + * + *

This checks the signature, but does not verify that the first key is the root key, to allow + * appending blocks without knowing about the root key. + * + *

The root key check is performed in the verify method + * + *

This method uses the default symbol table + * + * @param data + * @return + */ + public static Biscuit fromBytes(byte[] data, PublicKey root) + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + return fromBytesWithSymbols(data, root, defaultSymbolTable()); + } + + /** + * Deserializes a Biscuit token from a byte array + * + *

This checks the signature, but does not verify that the first key is the root key, to allow + * appending blocks without knowing about the root key. + * + *

The root key check is performed in the verify method + * + *

This method uses the default symbol table + * + * @param data + * @return + */ + public static Biscuit fromBytes(byte[] data, KeyDelegate delegate) + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + return fromBytesWithSymbols(data, delegate, defaultSymbolTable()); + } + + /** + * Deserializes a Biscuit token from a byte array + * + *

This checks the signature, but does not verify that the first key is the root key, to allow + * appending blocks without knowing about the root key. + * + *

The root key check is performed in the verify method + * + * @param data + * @return + */ + public static Biscuit fromBytesWithSymbols(byte[] data, PublicKey root, SymbolTable symbolTable) + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + // System.out.println("will deserialize and verify token"); + SerializedBiscuit ser = SerializedBiscuit.fromBytes(data, root); + // System.out.println("deserialized token, will populate Biscuit structure"); + + return Biscuit.fromSerializedBiscuit(ser, symbolTable); + } + + /** + * Deserializes a Biscuit token from a byte array + * + *

This checks the signature, but does not verify that the first key is the root key, to allow + * appending blocks without knowing about the root key. + * + *

The root key check is performed in the verify method + * + * @param data + * @return + */ + public static Biscuit fromBytesWithSymbols(byte[] data, KeyDelegate delegate, SymbolTable symbolTable) + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + // System.out.println("will deserialize and verify token"); + SerializedBiscuit ser = SerializedBiscuit.fromBytes(data, delegate); + // System.out.println("deserialized token, will populate Biscuit structure"); + + return Biscuit.fromSerializedBiscuit(ser, symbolTable); + } + + /** + * Fills a Biscuit structure from a deserialized token + * + * @return + */ + static Biscuit fromSerializedBiscuit(SerializedBiscuit ser, SymbolTable symbolTable) throws Error { + Tuple2> t = ser.extractBlocks(symbolTable); + Block authority = t._1; + ArrayList blocks = t._2; + + List revocationIds = ser.revocationIdentifiers(); + + return new Biscuit(authority, blocks, symbolTable, ser, revocationIds); + } + + /** + * Creates a authorizer for this token + * + *

This function checks that the root key is the one we expect + * + * @return + */ + public Authorizer authorizer() throws Error.FailedLogic { + return Authorizer.make(this); + } + + /** + * Serializes a token to a byte array + * + * @return + */ + public byte[] serialize() throws Error.FormatError.SerializationError { + return this.serializedBiscuit.serialize(); + } + + /** + * Serializes a token to base 64 url String using RFC4648_URLSAFE + * + * @return String + * @throws Error.FormatError.SerializationError + */ + public String serializeBase64Url() throws Error.FormatError.SerializationError { + return Base64.getUrlEncoder().encodeToString(serialize()); + } + + /** + * Generates a new token from an existing one and a new block + * + * @param block new block (should be generated from a Block builder) + * @param algorithm algorithm to use for the ephemeral key pair + * @return + */ + public Biscuit attenuate(org.biscuitsec.biscuit.token.builder.Block block, Algorithm algorithm) + throws Error { + SecureRandom rng = new SecureRandom(); + KeyPair keypair = KeyPair.generate(algorithm, rng); + SymbolTable builderSymbols = new SymbolTable(this.symbolTable); + return attenuate(rng, keypair, block.build(builderSymbols)); + } + + public Biscuit attenuate( + final SecureRandom rng, + final KeyPair keypair, + org.biscuitsec.biscuit.token.builder.Block block) + throws Error { + SymbolTable builderSymbols = new SymbolTable(this.symbolTable); + return attenuate(rng, keypair, block.build(builderSymbols)); + } + + /** + * Generates a new token from an existing one and a new block + * + * @param rng random number generator + * @param keypair ephemeral key pair + * @param block new block (should be generated from a Block builder) + * @return + */ + public Biscuit attenuate(final SecureRandom rng, final KeyPair keypair, Block block) + throws Error { + Biscuit copiedBiscuit = this.copy(); + + if (!copiedBiscuit.symbolTable.disjoint(block.getSymbolTable())) { + throw new Error.SymbolTableOverlap(); } - /** - * Deserializes a Biscuit token from a byte array - *

- * This checks the signature, but does not verify that the first key is the root key, - * to allow appending blocks without knowing about the root key. - *

- * The root key check is performed in the verify method - * - * @param data - * @return - */ - static public Biscuit from_bytes_with_symbols(byte[] data, PublicKey root, SymbolTable symbols) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - //System.out.println("will deserialize and verify token"); - SerializedBiscuit ser = SerializedBiscuit.from_bytes(data, root); - //System.out.println("deserialized token, will populate Biscuit structure"); - - return Biscuit.from_serialized_biscuit(ser, symbols); + Either containerRes = + copiedBiscuit.serializedBiscuit.append(keypair, block, Option.none()); + if (containerRes.isLeft()) { + throw containerRes.getLeft(); } - /** - * Deserializes a Biscuit token from a byte array - *

- * This checks the signature, but does not verify that the first key is the root key, - * to allow appending blocks without knowing about the root key. - *

- * The root key check is performed in the verify method - * - * @param data - * @return - */ - static public Biscuit from_bytes_with_symbols(byte[] data, KeyDelegate delegate, SymbolTable symbols) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - //System.out.println("will deserialize and verify token"); - SerializedBiscuit ser = SerializedBiscuit.from_bytes(data, delegate); - //System.out.println("deserialized token, will populate Biscuit structure"); - - return Biscuit.from_serialized_biscuit(ser, symbols); + SymbolTable symbolTable = new SymbolTable(copiedBiscuit.symbolTable); + for (String s : block.getSymbolTable().symbols()) { + symbolTable.add(s); } - /** - * Fills a Biscuit structure from a deserialized token - * - * @return - */ - static Biscuit from_serialized_biscuit(SerializedBiscuit ser, SymbolTable symbols) throws Error { - Tuple2> t = ser.extractBlocks(symbols); - Block authority = t._1; - ArrayList blocks = t._2; - - List revocation_ids = ser.revocation_identifiers(); - - return new Biscuit(authority, blocks, symbols, ser, revocation_ids); + for (PublicKey pk : block.getPublicKeys()) { + symbolTable.insert(pk); } - /** - * Creates a authorizer for this token - *

- * This function checks that the root key is the one we expect - * - * @return - */ - public Authorizer authorizer() throws Error.FailedLogic { - return Authorizer.make(this); + ArrayList blocks = new ArrayList<>(); + for (Block b : copiedBiscuit.blocks) { + blocks.add(b); } - - /** - * Serializes a token to a byte array - * - * @return - */ - public byte[] serialize() throws Error.FormatError.SerializationError { - return this.serializedBiscuit.serialize(); + blocks.add(block); + + SerializedBiscuit container = containerRes.get(); + List revocationIds = container.revocationIdentifiers(); + + return new Biscuit(copiedBiscuit.authority, blocks, symbolTable, container, revocationIds); + } + + /** Generates a third party block request from a token */ + public Biscuit appendThirdPartyBlock(PublicKey externalKey, ThirdPartyBlockContents blockResponse) + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + UnverifiedBiscuit b = super.appendThirdPartyBlock(externalKey, blockResponse); + + // no need to verify again, we are already working from a verified token + return Biscuit.fromSerializedBiscuit(b.serializedBiscuit, b.symbolTable); + } + + /** Prints a token's content */ + public String print() { + StringBuilder s = new StringBuilder(); + s.append("Biscuit {\n\tsymbols: "); + s.append(this.symbolTable.getAllSymbols()); + s.append("\n\tpublic keys: "); + s.append(this.symbolTable.getPublicKeys()); + s.append("\n\tauthority: "); + s.append(this.authority.print(this.symbolTable)); + s.append("\n\tblocks: [\n"); + for (Block b : this.blocks) { + s.append("\t\t"); + if (b.getExternalKey().isDefined()) { + s.append(b.print(b.getSymbolTable())); + } else { + s.append(b.print(this.symbolTable)); + } + s.append("\n"); } + s.append("\t]\n}"); - /** - * Serializes a token to a base 64 String - * - * @return - */ - @Deprecated - public String serialize_b64() throws Error.FormatError.SerializationError { - return Base64.getUrlEncoder().encodeToString(serialize()); - } - - /** - * Serializes a token to base 64 url String using RFC4648_URLSAFE - * - * @return String - * @throws Error.FormatError.SerializationError - */ - public String serialize_b64url() throws Error.FormatError.SerializationError { - return Base64.getUrlEncoder().encodeToString(serialize()); - } - - /** - * Generates a new token from an existing one and a new block - * - * @param block new block (should be generated from a Block builder) - * @param algorithm algorithm to use for the ephemeral key pair - * @return - */ - public Biscuit attenuate(org.biscuitsec.biscuit.token.builder.Block block, Algorithm algorithm) throws Error { - SecureRandom rng = new SecureRandom(); - KeyPair keypair = KeyPair.generate(algorithm, rng); - SymbolTable builderSymbols = new SymbolTable(this.symbols); - return attenuate(rng, keypair, block.build(builderSymbols)); - } + return s.toString(); + } - public Biscuit attenuate(final SecureRandom rng, final KeyPair keypair, org.biscuitsec.biscuit.token.builder.Block block) throws Error { - SymbolTable builderSymbols = new SymbolTable(this.symbols); - return attenuate(rng, keypair, block.build(builderSymbols)); - } - - /** - * Generates a new token from an existing one and a new block - * - * @param rng random number generator - * @param keypair ephemeral key pair - * @param block new block (should be generated from a Block builder) - * @return - */ - public Biscuit attenuate(final SecureRandom rng, final KeyPair keypair, Block block) throws Error { - Biscuit copiedBiscuit = this.copy(); - - if (!Collections.disjoint(copiedBiscuit.symbols.symbols, block.symbols.symbols)) { - throw new Error.SymbolTableOverlap(); - } - - Either containerRes = copiedBiscuit.serializedBiscuit.append(keypair, block, Option.none()); - if (containerRes.isLeft()) { - throw containerRes.getLeft(); - } - SerializedBiscuit container = containerRes.get(); - - SymbolTable symbols = new SymbolTable(copiedBiscuit.symbols); - for (String s : block.symbols.symbols) { - symbols.add(s); - } - - for(PublicKey pk: block.publicKeys) { - symbols.insert(pk); - } - - ArrayList blocks = new ArrayList<>(); - for (Block b : copiedBiscuit.blocks) { - blocks.add(b); - } - blocks.add(block); - - List revocation_ids = container.revocation_identifiers(); - - return new Biscuit(copiedBiscuit.authority, blocks, symbols, container, revocation_ids); - } - - /** - * Generates a third party block request from a token - */ - public Biscuit appendThirdPartyBlock(PublicKey externalKey, ThirdPartyBlockContents blockResponse) - throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - UnverifiedBiscuit b = super.appendThirdPartyBlock(externalKey, blockResponse); - - // no need to verify again, we are already working from a verified token - return Biscuit.from_serialized_biscuit(b.serializedBiscuit, b.symbols); - } - - /** - * Prints a token's content - */ - public String print() { - StringBuilder s = new StringBuilder(); - s.append("Biscuit {\n\tsymbols: "); - s.append(this.symbols.getAllSymbols()); - s.append("\n\tpublic keys: "); - s.append(this.symbols.publicKeys()); - s.append("\n\tauthority: "); - s.append(this.authority.print(this.symbols)); - s.append("\n\tblocks: [\n"); - for (Block b : this.blocks) { - s.append("\t\t"); - if(b.externalKey.isDefined()) { - s.append(b.print(b.symbols)); - } else { - s.append(b.print(this.symbols)); - } - s.append("\n"); - } - s.append("\t]\n}"); - - return s.toString(); - } - - public Biscuit copy() throws Error { - return Biscuit.from_serialized_biscuit(this.serializedBiscuit, this.symbols); - } + public Biscuit copy() throws Error { + return Biscuit.fromSerializedBiscuit(this.serializedBiscuit, this.symbolTable); + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/Block.java b/src/main/java/org/biscuitsec/biscuit/token/Block.java index 1852271d..c87ba63d 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/Block.java +++ b/src/main/java/org/biscuitsec/biscuit/token/Block.java @@ -1,422 +1,504 @@ package org.biscuitsec.biscuit.token; +import static io.vavr.API.Left; +import static io.vavr.API.Right; + import biscuit.format.schema.Schema; -import org.biscuitsec.biscuit.crypto.PublicKey; -import org.biscuitsec.biscuit.datalog.expressions.Expression; -import org.biscuitsec.biscuit.datalog.expressions.Op; -import org.biscuitsec.biscuit.error.Error; -import org.biscuitsec.biscuit.datalog.*; -import org.biscuitsec.biscuit.token.format.SerializedBiscuit; import com.google.protobuf.InvalidProtocolBufferException; import io.vavr.control.Either; import io.vavr.control.Option; - import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Objects; +import org.biscuitsec.biscuit.crypto.PublicKey; +import org.biscuitsec.biscuit.datalog.Check; +import org.biscuitsec.biscuit.datalog.Fact; +import org.biscuitsec.biscuit.datalog.Rule; +import org.biscuitsec.biscuit.datalog.SchemaVersion; +import org.biscuitsec.biscuit.datalog.Scope; +import org.biscuitsec.biscuit.datalog.SymbolTable; +import org.biscuitsec.biscuit.datalog.expressions.Expression; +import org.biscuitsec.biscuit.datalog.expressions.Op; +import org.biscuitsec.biscuit.error.Error; +import org.biscuitsec.biscuit.token.format.SerializedBiscuit; -import static io.vavr.API.Left; -import static io.vavr.API.Right; - -/** - * Represents a token's block with its checks - */ -public class Block { - final SymbolTable symbols; - final String context; - final List facts; - final List rules; - final List checks; - final List scopes; - final List publicKeys; - Option externalKey; - long version; - - /** - * creates a new block - * - * @param base_symbols - */ - public Block(SymbolTable base_symbols) { - this.symbols = base_symbols; - this.context = ""; - this.facts = new ArrayList<>(); - this.rules = new ArrayList<>(); - this.checks = new ArrayList<>(); - this.scopes = new ArrayList<>(); - this.publicKeys = new ArrayList<>(); - this.externalKey = Option.none(); - } - - /** - * creates a new block - * - * @param base_symbols - * @param facts - * @param checks - */ - public Block(SymbolTable base_symbols, String context, List facts, List rules, List checks, - List scopes, List publicKeys, Option externalKey, int version) { - this.symbols = base_symbols; - this.context = context; - this.facts = facts; - this.rules = rules; - this.checks = checks; - this.scopes = scopes; - this.publicKeys = publicKeys; - this.externalKey = externalKey; - } - - public SymbolTable symbols() { - return symbols; - } - - public List publicKeys() { - return publicKeys; - } - - public void setExternalKey(PublicKey externalKey) { - this.externalKey = Option.some(externalKey); - } - - /** - * pretty printing for a block - * - * @param symbol_table - * @return - */ - public String print(SymbolTable symbol_table) { - StringBuilder s = new StringBuilder(); - - SymbolTable local_symbols; - if(this.externalKey.isDefined()) { - local_symbols = new SymbolTable(this.symbols); - for(PublicKey pk: symbol_table.publicKeys()) { - local_symbols.insert(pk); - } - } else { - local_symbols = symbol_table; - } - s.append("Block"); - s.append(" {\n\t\tsymbols: "); - s.append(this.symbols.symbols); - s.append("\n\t\tsymbol public keys: "); - s.append(this.symbols.publicKeys()); - s.append("\n\t\tblock public keys: "); - s.append(this.publicKeys); - s.append("\n\t\tcontext: "); - s.append(this.context); - if(this.externalKey.isDefined()) { - s.append("\n\t\texternal key: "); - s.append(this.externalKey.get().toString()); - } - s.append("\n\t\tscopes: ["); - for (Scope scope : this.scopes) { - s.append("\n\t\t\t"); - s.append(symbol_table.print_scope(scope)); - } - s.append("\n\t\t]\n\t\tfacts: ["); - for (Fact f : this.facts) { - s.append("\n\t\t\t"); - s.append(local_symbols.print_fact(f)); - } - s.append("\n\t\t]\n\t\trules: ["); - for (Rule r : this.rules) { - s.append("\n\t\t\t"); - s.append(local_symbols.print_rule(r)); - } - s.append("\n\t\t]\n\t\tchecks: ["); - for (Check c : this.checks) { - s.append("\n\t\t\t"); - s.append(local_symbols.print_check(c)); - } - s.append("\n\t\t]\n\t}"); - - return s.toString(); +/** Represents a token's block with its checks */ +public final class Block { + private final SymbolTable symbolTable; + private final String context; + private final List facts; + private final List rules; + private final List checks; + private final List scopes; + private final List publicKeys; + private Option externalKey; + private long version; + + /** + * creates a new block + * + * @param baseSymbols + */ + public Block(SymbolTable baseSymbols) { + this.symbolTable = baseSymbols; + this.context = ""; + this.facts = new ArrayList<>(); + this.rules = new ArrayList<>(); + this.checks = new ArrayList<>(); + this.scopes = new ArrayList<>(); + this.publicKeys = new ArrayList<>(); + this.externalKey = Option.none(); + } + + /** + * creates a new block + * + * @param baseSymbols + * @param facts + * @param checks + */ + public Block( + SymbolTable baseSymbols, + String context, + List facts, + List rules, + List checks, + List scopes, + List publicKeys, + Option externalKey, + int version) { + this.symbolTable = baseSymbols; + this.context = context; + this.facts = facts; + this.rules = rules; + this.checks = checks; + this.scopes = scopes; + this.publicKeys = publicKeys; + this.externalKey = externalKey; + } + + public SymbolTable getSymbolTable() { + return this.symbolTable; + } + + public void setExternalKey(PublicKey externalKey) { + this.externalKey = Option.some(externalKey); + } + + /** + * pretty printing for a block + * + * @param symbolTable + * @return + */ + public String print(SymbolTable symbolTable) { + StringBuilder s = new StringBuilder(); + + SymbolTable localSymbols; + if (this.externalKey.isDefined()) { + localSymbols = new SymbolTable(this.symbolTable); + for (PublicKey pk : symbolTable.getPublicKeys()) { + localSymbols.insert(pk); + } + } else { + localSymbols = symbolTable; + } + s.append("Block"); + s.append(" {\n\t\tsymbols: "); + s.append(this.getSymbolTable().symbols()); + s.append("\n\t\tsymbol public keys: "); + s.append(this.symbolTable.getPublicKeys()); + s.append("\n\t\tblock public keys: "); + s.append(this.publicKeys); + s.append("\n\t\tcontext: "); + s.append(this.context); + if (this.externalKey.isDefined()) { + s.append("\n\t\texternal key: "); + s.append(this.externalKey.get().toString()); + } + s.append("\n\t\tscopes: ["); + for (Scope scope : this.scopes) { + s.append("\n\t\t\t"); + s.append(symbolTable.formatScope(scope)); + } + s.append("\n\t\t]\n\t\tfacts: ["); + for (Fact f : this.facts) { + s.append("\n\t\t\t"); + s.append(localSymbols.formatFact(f)); + } + s.append("\n\t\t]\n\t\trules: ["); + for (Rule r : this.rules) { + s.append("\n\t\t\t"); + s.append(localSymbols.formatRule(r)); + } + s.append("\n\t\t]\n\t\tchecks: ["); + for (Check c : this.checks) { + s.append("\n\t\t\t"); + s.append(localSymbols.formatCheck(c)); + } + s.append("\n\t\t]\n\t}"); + + return s.toString(); + } + + public String printCode(SymbolTable symbolTable) { + StringBuilder s = new StringBuilder(); + + SymbolTable localSymbols; + if (this.externalKey.isDefined()) { + localSymbols = new SymbolTable(this.symbolTable); + for (PublicKey pk : symbolTable.getPublicKeys()) { + localSymbols.insert(pk); + } + } else { + localSymbols = symbolTable; + } + /*s.append("Block"); + s.append(" {\n\t\tsymbols: "); + s.append(this.symbols.symbols); + s.append("\n\t\tsymbol public keys: "); + s.append(this.symbols.publicKeys()); + s.append("\n\t\tblock public keys: "); + s.append(this.publicKeys); + s.append("\n\t\tcontext: "); + s.append(this.context); + if (this.externalKey.isDefined()) { + s.append("\n\t\texternal key: "); + s.append(this.externalKey.get().toString()); + }*/ + for (Scope scope : this.scopes) { + s.append("trusting " + localSymbols.formatScope(scope) + "\n"); + } + for (Fact f : this.facts) { + s.append(localSymbols.formatFact(f) + ";\n"); + } + for (Rule r : this.rules) { + s.append(localSymbols.formatRule(r) + ";\n"); + } + for (Check c : this.checks) { + s.append(localSymbols.formatCheck(c) + ";\n"); } - public String printCode(SymbolTable symbol_table) { - StringBuilder s = new StringBuilder(); + return s.toString(); + } - SymbolTable local_symbols; - if(this.externalKey.isDefined()) { - local_symbols = new SymbolTable(this.symbols); - for(PublicKey pk: symbol_table.publicKeys()) { - local_symbols.insert(pk); - } - } else { - local_symbols = symbol_table; - } - /*s.append("Block"); - s.append(" {\n\t\tsymbols: "); - s.append(this.symbols.symbols); - s.append("\n\t\tsymbol public keys: "); - s.append(this.symbols.publicKeys()); - s.append("\n\t\tblock public keys: "); - s.append(this.publicKeys); - s.append("\n\t\tcontext: "); - s.append(this.context); - if(this.externalKey.isDefined()) { - s.append("\n\t\texternal key: "); - s.append(this.externalKey.get().toString()); - }*/ - for (Scope scope : this.scopes) { - s.append("trusting "+local_symbols.print_scope(scope)+"\n"); - } - for (Fact f : this.facts) { - s.append(local_symbols.print_fact(f)+";\n"); - } - for (Rule r : this.rules) { - s.append(local_symbols.print_rule(r)+";\n"); - } - for (Check c : this.checks) { - s.append(local_symbols.print_check(c)+";\n"); - } + /** + * Serializes a Block to its Protobuf representation + * + * @return + */ + public Schema.Block serialize() { + Schema.Block.Builder b = Schema.Block.newBuilder(); - return s.toString(); + for (int i = 0; i < this.getSymbolTable().symbols().size(); i++) { + b.addSymbols(this.getSymbolTable().symbols().get(i)); } - /** - * Serializes a Block to its Protobuf representation - * - * @return - */ - public Schema.Block serialize() { - Schema.Block.Builder b = Schema.Block.newBuilder(); + if (!this.context.isEmpty()) { + b.setContext(this.context); + } - for (int i = 0; i < this.symbols.symbols.size(); i++) { - b.addSymbols(this.symbols.symbols.get(i)); - } + for (int i = 0; i < this.facts.size(); i++) { + b.addFactsV2(this.facts.get(i).serialize()); + } - if (!this.context.isEmpty()) { - b.setContext(this.context); - } + for (int i = 0; i < this.rules.size(); i++) { + b.addRulesV2(this.rules.get(i).serialize()); + } - for (int i = 0; i < this.facts.size(); i++) { - b.addFactsV2(this.facts.get(i).serialize()); - } + for (int i = 0; i < this.checks.size(); i++) { + b.addChecksV2(this.checks.get(i).serialize()); + } - for (int i = 0; i < this.rules.size(); i++) { - b.addRulesV2(this.rules.get(i).serialize()); - } + for (Scope scope : this.scopes) { + b.addScope(scope.serialize()); + } - for (int i = 0; i < this.checks.size(); i++) { - b.addChecksV2(this.checks.get(i).serialize()); - } + for (PublicKey pk : this.publicKeys) { + b.addPublicKeys(pk.serialize()); + } - for (Scope scope: this.scopes) { - b.addScope(scope.serialize()); - } + b.setVersion(getSchemaVersion()); + return b.build(); + } - for(PublicKey pk: this.publicKeys) { - b.addPublicKeys(pk.serialize()); - } + int getSchemaVersion() { + boolean containsScopes = !this.scopes.isEmpty(); + boolean containsCheckAll = false; + boolean containsV4 = false; - b.setVersion(getSchemaVersion()); - return b.build(); + for (Rule r : this.rules) { + containsScopes |= !r.scopes().isEmpty(); + for (Expression e : r.expressions()) { + containsV4 |= containsV4Op(e); + } } + for (Check c : this.checks) { + containsCheckAll |= c.kind() == Check.Kind.ALL; - int getSchemaVersion() { - boolean containsScopes = !this.scopes.isEmpty(); - boolean containsCheckAll = false; - boolean containsV4 = false; - - for (Rule r: this.rules) { - containsScopes |= !r.scopes().isEmpty(); - for(Expression e: r.expressions()) { - containsV4 |= containsV4Op(e); - } - } - for(Check c: this.checks) { - containsCheckAll |= c.kind() == Check.Kind.All; - - for (Rule q: c.queries()) { - containsScopes |= !q.scopes().isEmpty(); - for(Expression e: q.expressions()) { - containsV4 |= containsV4Op(e); - } - } + for (Rule q : c.queries()) { + containsScopes |= !q.scopes().isEmpty(); + for (Expression e : q.expressions()) { + containsV4 |= containsV4Op(e); } + } + } - if(this.externalKey.isDefined()) { - return SerializedBiscuit.MAX_SCHEMA_VERSION; + if (this.externalKey.isDefined()) { + return SerializedBiscuit.MAX_SCHEMA_VERSION; - }else if(containsScopes || containsCheckAll || containsV4) { - return 4; - } else { - return SerializedBiscuit.MIN_SCHEMA_VERSION; - } + } else if (containsScopes || containsCheckAll || containsV4) { + return 4; + } else { + return SerializedBiscuit.MIN_SCHEMA_VERSION; + } + } + + boolean containsV4Op(Expression e) { + for (Op op : e.getOps()) { + if (op instanceof Op.Binary) { + Op.BinaryOp o = ((Op.Binary) op).getOp(); + if (o == Op.BinaryOp.BitwiseAnd + || o == Op.BinaryOp.BitwiseOr + || o == Op.BinaryOp.BitwiseXor + || o == Op.BinaryOp.NotEqual) { + return true; + } + } } - boolean containsV4Op(Expression e) { - for (Op op: e.getOps()) { - if (op instanceof Op.Binary) { - Op.BinaryOp o = ((Op.Binary) op).getOp(); - if (o == Op.BinaryOp.BitwiseAnd || o == Op.BinaryOp.BitwiseOr || o == Op.BinaryOp.BitwiseXor || o == Op.BinaryOp.NotEqual) { - return true; - } - } - } - - return false; + return false; + } + + /** + * Deserializes a block from its Protobuf representation + * + * @param b + * @return + */ + public static Either deserialize( + Schema.Block b, Option externalKey) { + int version = b.getVersion(); + if (version < SerializedBiscuit.MIN_SCHEMA_VERSION + || version > SerializedBiscuit.MAX_SCHEMA_VERSION) { + return Left( + new Error.FormatError.Version( + SerializedBiscuit.MIN_SCHEMA_VERSION, SerializedBiscuit.MAX_SCHEMA_VERSION, version)); } - /** - * Deserializes a block from its Protobuf representation - * - * @param b - * @return - */ - static public Either deserialize(Schema.Block b, Option externalKey) { - int version = b.getVersion(); - if (version < SerializedBiscuit.MIN_SCHEMA_VERSION || version > SerializedBiscuit.MAX_SCHEMA_VERSION) { - return Left(new Error.FormatError.Version(SerializedBiscuit.MIN_SCHEMA_VERSION, SerializedBiscuit.MAX_SCHEMA_VERSION, version)); - } + SymbolTable newSymbolTable = new SymbolTable(); + for (String s : b.getSymbolsList()) { + newSymbolTable.add(s); + } - SymbolTable symbols = new SymbolTable(); - for (String s : b.getSymbolsList()) { - symbols.add(s); - } + ArrayList facts = new ArrayList<>(); + ArrayList rules = new ArrayList<>(); + ArrayList checks = new ArrayList<>(); + + for (Schema.FactV2 fact : b.getFactsV2List()) { + Either res = Fact.deserializeV2(fact); + if (res.isLeft()) { + Error.FormatError e = res.getLeft(); + return Left(e); + } else { + facts.add(res.get()); + } + } - ArrayList facts = new ArrayList<>(); - ArrayList rules = new ArrayList<>(); - ArrayList checks = new ArrayList<>(); - - for (Schema.FactV2 fact : b.getFactsV2List()) { - Either res = Fact.deserializeV2(fact); - if (res.isLeft()) { - Error.FormatError e = res.getLeft(); - return Left(e); - } else { - facts.add(res.get()); - } - } + for (Schema.RuleV2 rule : b.getRulesV2List()) { + Either res = Rule.deserializeV2(rule); + if (res.isLeft()) { + Error.FormatError e = res.getLeft(); + return Left(e); + } else { + rules.add(res.get()); + } + } + for (Schema.CheckV2 check : b.getChecksV2List()) { + Either res = Check.deserializeV2(check); + if (res.isLeft()) { + Error.FormatError e = res.getLeft(); + return Left(e); + } else { + checks.add(res.get()); + } + } - for (Schema.RuleV2 rule : b.getRulesV2List()) { - Either res = Rule.deserializeV2(rule); - if (res.isLeft()) { - Error.FormatError e = res.getLeft(); - return Left(e); - } else { - rules.add(res.get()); - } - } + ArrayList scopes = new ArrayList<>(); + for (Schema.Scope scope : b.getScopeList()) { + Either res = Scope.deserialize(scope); + if (res.isLeft()) { + Error.FormatError e = res.getLeft(); + return Left(e); + } else { + scopes.add(res.get()); + } + } + ArrayList publicKeys = new ArrayList<>(); + for (Schema.PublicKey pk : b.getPublicKeysList()) { + try { + PublicKey key = PublicKey.deserialize(pk); + publicKeys.add(key); + newSymbolTable.getPublicKeys().add(key); + } catch (Error.FormatError e) { + return Left(e); + } + } - for (Schema.CheckV2 check : b.getChecksV2List()) { - Either res = Check.deserializeV2(check); - if (res.isLeft()) { - Error.FormatError e = res.getLeft(); - return Left(e); - } else { - checks.add(res.get()); - } - } + SchemaVersion schemaVersion = new SchemaVersion(facts, rules, checks, scopes); + Either res = schemaVersion.checkCompatibility(version); + if (res.isLeft()) { + Error.FormatError e = res.getLeft(); + return Left(e); + } - ArrayList scopes = new ArrayList<>(); - for (Schema.Scope scope: b.getScopeList()) { - Either res = Scope.deserialize(scope); - if(res.isLeft()) { - Error.FormatError e = res.getLeft(); - return Left(e); - } else { - scopes.add(res.get()); - } - } + return Right( + new Block( + newSymbolTable, + b.getContext(), + facts, + rules, + checks, + scopes, + publicKeys, + externalKey, + version)); + } + + /** + * Deserializes a Block from a byte array + * + * @param slice + * @return + */ + public static Either fromBytes( + byte[] slice, Option externalKey) { + try { + Schema.Block data = Schema.Block.parseFrom(slice); + return Block.deserialize(data, externalKey); + } catch (InvalidProtocolBufferException e) { + return Left(new Error.FormatError.DeserializationError(e.toString())); + } + } + + public Either toBytes() { + Schema.Block b = this.serialize(); + try { + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + b.writeTo(stream); + byte[] data = stream.toByteArray(); + return Right(data); + } catch (IOException e) { + return Left(new Error.FormatError.SerializationError(e.toString())); + } + } - ArrayList publicKeys = new ArrayList<>(); - for (Schema.PublicKey pk: b.getPublicKeysList()) { - try { - PublicKey key =PublicKey.deserialize(pk); - publicKeys.add(key); - symbols.publicKeys().add(key); - } catch(Error.FormatError e) { - return Left(e); - } - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } - SchemaVersion schemaVersion = new SchemaVersion(facts, rules, checks, scopes); - Either res = schemaVersion.checkCompatibility(version); - if (res.isLeft()) { - Error.FormatError e = res.getLeft(); - return Left(e); - } + Block block = (Block) o; - return Right(new Block(symbols, b.getContext(), facts, rules, checks, scopes, publicKeys, externalKey, version)); - } - - /** - * Deserializes a Block from a byte array - * - * @param slice - * @return - */ - static public Either from_bytes(byte[] slice, Option externalKey) { - try { - Schema.Block data = Schema.Block.parseFrom(slice); - return Block.deserialize(data, externalKey); - } catch (InvalidProtocolBufferException e) { - return Left(new Error.FormatError.DeserializationError(e.toString())); - } + if (!Objects.equals(symbolTable, block.symbolTable)) { + return false; } - - public Either to_bytes() { - Schema.Block b = this.serialize(); - try { - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - b.writeTo(stream); - byte[] data = stream.toByteArray(); - return Right(data); - } catch (IOException e) { - return Left(new Error.FormatError.SerializationError(e.toString())); - } + if (!Objects.equals(context, block.context)) { + return false; } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - Block block = (Block) o; - - if (!Objects.equals(symbols, block.symbols)) return false; - if (!Objects.equals(context, block.context)) return false; - if (!Objects.equals(facts, block.facts)) return false; - if (!Objects.equals(rules, block.rules)) return false; - if (!Objects.equals(checks, block.checks)) return false; - if (!Objects.equals(scopes, block.scopes)) return false; - if (!Objects.equals(publicKeys, block.publicKeys)) return false; - return Objects.equals(externalKey, block.externalKey); - } - - @Override - public int hashCode() { - int result = symbols != null ? symbols.hashCode() : 0; - result = 31 * result + (context != null ? context.hashCode() : 0); - result = 31 * result + (facts != null ? facts.hashCode() : 0); - result = 31 * result + (rules != null ? rules.hashCode() : 0); - result = 31 * result + (checks != null ? checks.hashCode() : 0); - result = 31 * result + (scopes != null ? scopes.hashCode() : 0); - result = 31 * result + (publicKeys != null ? publicKeys.hashCode() : 0); - result = 31 * result + (externalKey != null ? externalKey.hashCode() : 0); - return result; - } - - @Override - public String toString() { - return "Block{" + - "symbols=" + symbols + - ", context='" + context + '\'' + - ", facts=" + facts + - ", rules=" + rules + - ", checks=" + checks + - ", scopes=" + scopes + - ", publicKeys=" + publicKeys + - ", externalKey=" + externalKey + - '}'; + if (!Objects.equals(facts, block.facts)) { + return false; } + if (!Objects.equals(rules, block.rules)) { + return false; + } + if (!Objects.equals(checks, block.checks)) { + return false; + } + if (!Objects.equals(scopes, block.scopes)) { + return false; + } + if (!Objects.equals(publicKeys, block.publicKeys)) { + return false; + } + return Objects.equals(externalKey, block.externalKey); + } + + @Override + public int hashCode() { + int result = symbolTable != null ? symbolTable.hashCode() : 0; + result = 31 * result + (context != null ? context.hashCode() : 0); + result = 31 * result + (facts != null ? facts.hashCode() : 0); + result = 31 * result + (rules != null ? rules.hashCode() : 0); + result = 31 * result + (checks != null ? checks.hashCode() : 0); + result = 31 * result + (scopes != null ? scopes.hashCode() : 0); + result = 31 * result + (publicKeys != null ? publicKeys.hashCode() : 0); + result = 31 * result + (externalKey != null ? externalKey.hashCode() : 0); + return result; + } + + @Override + public String toString() { + return "Block{" + + "symbols=" + + symbolTable + + ", context='" + + context + + '\'' + + ", facts=" + + facts + + ", rules=" + + rules + + ", checks=" + + checks + + ", scopes=" + + scopes + + ", publicKeys=" + + publicKeys + + ", externalKey=" + + externalKey + + '}'; + } + + public String getContext() { + return this.context; + } + + public List getFacts() { + return Collections.unmodifiableList(facts); + } + + public List getRules() { + return Collections.unmodifiableList(rules); + } + + public List getChecks() { + return Collections.unmodifiableList(checks); + } + + + public List getPublicKeys() { + return Collections.unmodifiableList(this.publicKeys); + } + + public List getScopes() { + return Collections.unmodifiableList(scopes); + } + + public Option getExternalKey() { + return externalKey; + } + + public long getVersion() { + return version; + } } - diff --git a/src/main/java/org/biscuitsec/biscuit/token/Policy.java b/src/main/java/org/biscuitsec/biscuit/token/Policy.java index 139e4b92..79dc3b8b 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/Policy.java +++ b/src/main/java/org/biscuitsec/biscuit/token/Policy.java @@ -1,44 +1,52 @@ package org.biscuitsec.biscuit.token; -import org.biscuitsec.biscuit.token.builder.Rule; - import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; +import org.biscuitsec.biscuit.token.builder.Rule; -public class Policy { - public enum Kind { - Allow, - Deny, - } - - public final List queries; - public Kind kind; - - public Policy(List queries, Kind kind) { - this.queries = queries; - this.kind = kind; - } - - public Policy(Rule query, Kind kind) { - ArrayList r = new ArrayList<>(); - r.add(query); - - this.queries = r; - this.kind = kind; - } - - @Override - public String toString() { - final List qs = queries.stream().map((q) -> q.bodyToString()).collect(Collectors.toList()); - - switch(this.kind) { - case Allow: - return "allow if "+String.join(" or ", qs); - case Deny: - return "deny if "+String.join(" or ", qs); - } +public final class Policy { + public List queries() { + return queries; + } + + public Kind kind() { + return kind; + } + + public enum Kind { + ALLOW, + DENY, + } + + private final List queries; + private Kind kind; + + public Policy(List queries, Kind kind) { + this.queries = queries; + this.kind = kind; + } + + public Policy(Rule query, Kind kind) { + ArrayList r = new ArrayList<>(); + r.add(query); + + this.queries = r; + this.kind = kind; + } + + @Override + public String toString() { + final List qs = + queries.stream().map((q) -> q.bodyToString()).collect(Collectors.toList()); + + switch (this.kind) { + case ALLOW: + return "allow if " + String.join(" or ", qs); + case DENY: + return "deny if " + String.join(" or ", qs); + default: return null; } - + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/RevocationIdentifier.java b/src/main/java/org/biscuitsec/biscuit/token/RevocationIdentifier.java index 5964cd23..1a139fb8 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/RevocationIdentifier.java +++ b/src/main/java/org/biscuitsec/biscuit/token/RevocationIdentifier.java @@ -1,42 +1,43 @@ package org.biscuitsec.biscuit.token; -import org.biscuitsec.biscuit.token.builder.Utils; - import java.util.Base64; +import org.biscuitsec.biscuit.token.builder.Utils; -public class RevocationIdentifier { - private byte[] bytes; - - public RevocationIdentifier(byte[] bytes) { - this.bytes = bytes; - } - - /** - * Creates a RevocationIdentifier from base64 url (RFC4648_URLSAFE) - * @param b64url serialized revocation identifier - * @return RevocationIdentifier - */ - public static RevocationIdentifier from_b64url(String b64url) { - return new RevocationIdentifier(Base64.getDecoder().decode(b64url)); - } - - /** - * Serializes a revocation identifier as base64 url (RFC4648_URLSAFE) - * @return String - */ - public String serialize_b64url() { - return Base64.getEncoder().encodeToString(this.bytes); - } - - public String toHex() { - return Utils.byteArrayToHexString(this.bytes).toLowerCase(); - } - - public static RevocationIdentifier from_bytes(byte[] bytes) { - return new RevocationIdentifier(bytes); - } - - public byte[] getBytes() { - return this.bytes; - } +public final class RevocationIdentifier { + private byte[] bytes; + + public RevocationIdentifier(byte[] bytes) { + this.bytes = bytes; + } + + /** + * Creates a RevocationIdentifier from base64 url (RFC4648_URLSAFE) + * + * @param b64url serialized revocation identifier + * @return RevocationIdentifier + */ + public static RevocationIdentifier fromBase64Url(String b64url) { + return new RevocationIdentifier(Base64.getDecoder().decode(b64url)); + } + + /** + * Serializes a revocation identifier as base64 url (RFC4648_URLSAFE) + * + * @return String + */ + public String serializeBase64Url() { + return Base64.getEncoder().encodeToString(this.bytes); + } + + public String toHex() { + return Utils.byteArrayToHexString(this.bytes).toLowerCase(); + } + + public static RevocationIdentifier fromBytes(byte[] bytes) { + return new RevocationIdentifier(bytes); + } + + public byte[] getBytes() { + return this.bytes; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/ThirdPartyBlockContents.java b/src/main/java/org/biscuitsec/biscuit/token/ThirdPartyBlockContents.java index 812589cf..5bcfc990 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/ThirdPartyBlockContents.java +++ b/src/main/java/org/biscuitsec/biscuit/token/ThirdPartyBlockContents.java @@ -3,83 +3,106 @@ import biscuit.format.schema.Schema; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; -import org.biscuitsec.biscuit.crypto.PublicKey; -import org.biscuitsec.biscuit.error.Error; - import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; -import java.util.List; import java.util.Objects; +import org.biscuitsec.biscuit.crypto.PublicKey; +import org.biscuitsec.biscuit.error.Error; -public class ThirdPartyBlockContents { - byte[] payload; - byte[] signature; - PublicKey publicKey; +public final class ThirdPartyBlockContents { + private byte[] payload; + private byte[] signature; + private PublicKey publicKey; - ThirdPartyBlockContents(byte[] payload, byte[] signature, PublicKey publicKey) { - this.payload = payload; - this.signature = signature; - this.publicKey = publicKey; - } + ThirdPartyBlockContents(byte[] payload, byte[] signature, PublicKey publicKey) { + this.payload = payload; + this.signature = signature; + this.publicKey = publicKey; + } - public Schema.ThirdPartyBlockContents serialize() throws Error.FormatError.SerializationError { - Schema.ThirdPartyBlockContents.Builder b = Schema.ThirdPartyBlockContents.newBuilder(); - b.setPayload(ByteString.copyFrom(this.payload)); - b.setExternalSignature(b.getExternalSignatureBuilder() - .setSignature(ByteString.copyFrom(this.signature)) - .setPublicKey(this.publicKey.serialize()) - .build()); + public Schema.ThirdPartyBlockContents serialize() throws Error.FormatError.SerializationError { + Schema.ThirdPartyBlockContents.Builder b = Schema.ThirdPartyBlockContents.newBuilder(); + b.setPayload(ByteString.copyFrom(this.payload)); + b.setExternalSignature( + b.getExternalSignatureBuilder() + .setSignature(ByteString.copyFrom(this.signature)) + .setPublicKey(this.publicKey.serialize()) + .build()); - return b.build(); - } + return b.build(); + } - static public ThirdPartyBlockContents deserialize(Schema.ThirdPartyBlockContents b) throws Error.FormatError.DeserializationError { - byte[] payload = b.getPayload().toByteArray(); - byte[] signature = b.getExternalSignature().getSignature().toByteArray(); - PublicKey publicKey = PublicKey.deserialize(b.getExternalSignature().getPublicKey()); + public static ThirdPartyBlockContents deserialize(Schema.ThirdPartyBlockContents b) + throws Error.FormatError.DeserializationError { + byte[] payload = b.getPayload().toByteArray(); + byte[] signature = b.getExternalSignature().getSignature().toByteArray(); + PublicKey publicKey = PublicKey.deserialize(b.getExternalSignature().getPublicKey()); - return new ThirdPartyBlockContents(payload, signature, publicKey); - } + return new ThirdPartyBlockContents(payload, signature, publicKey); + } - static public ThirdPartyBlockContents fromBytes(byte[] slice) throws InvalidProtocolBufferException, Error.FormatError.DeserializationError { - return ThirdPartyBlockContents.deserialize(Schema.ThirdPartyBlockContents.parseFrom(slice)); + public static ThirdPartyBlockContents fromBytes(byte[] slice) + throws InvalidProtocolBufferException, Error.FormatError.DeserializationError { + return ThirdPartyBlockContents.deserialize(Schema.ThirdPartyBlockContents.parseFrom(slice)); + } + + public byte[] toBytes() throws IOException, Error.FormatError.SerializationError { + Schema.ThirdPartyBlockContents b = this.serialize(); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + b.writeTo(stream); + return stream.toByteArray(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } + if (o == null || getClass() != o.getClass()) { + return false; + } + + ThirdPartyBlockContents that = (ThirdPartyBlockContents) o; - public byte[] toBytes() throws IOException, Error.FormatError.SerializationError { - Schema.ThirdPartyBlockContents b = this.serialize(); - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - b.writeTo(stream); - return stream.toByteArray(); + if (!Arrays.equals(payload, that.payload)) { + return false; } + if (!Arrays.equals(signature, that.signature)) { + return false; + } + return Objects.equals(publicKey, that.publicKey); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + @Override + public int hashCode() { + int result = Arrays.hashCode(payload); + result = 31 * result + Arrays.hashCode(signature); + result = 31 * result + (publicKey != null ? publicKey.hashCode() : 0); + return result; + } - ThirdPartyBlockContents that = (ThirdPartyBlockContents) o; + @Override + public String toString() { + return "ThirdPartyBlockContents{" + + "payload=" + + Arrays.toString(payload) + + ", signature=" + + Arrays.toString(signature) + + ", publicKey=" + + publicKey + + '}'; + } - if (!Arrays.equals(payload, that.payload)) return false; - if (!Arrays.equals(signature, that.signature)) return false; - return Objects.equals(publicKey, that.publicKey); - } + public byte[] getPayload() { + return payload; + } - @Override - public int hashCode() { - int result = Arrays.hashCode(payload); - result = 31 * result + Arrays.hashCode(signature); - result = 31 * result + (publicKey != null ? publicKey.hashCode() : 0); - return result; - } + public byte[] getSignature() { + return signature; + } - @Override - public String toString() { - return "ThirdPartyBlockContents{" + - "payload=" + Arrays.toString(payload) + - ", signature=" + Arrays.toString(signature) + - ", publicKey=" + publicKey + - '}'; - } + public PublicKey getPublicKey() { + return publicKey; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/ThirdPartyBlockRequest.java b/src/main/java/org/biscuitsec/biscuit/token/ThirdPartyBlockRequest.java index 3ab63607..ca30319d 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/ThirdPartyBlockRequest.java +++ b/src/main/java/org/biscuitsec/biscuit/token/ThirdPartyBlockRequest.java @@ -4,6 +4,12 @@ import com.google.protobuf.InvalidProtocolBufferException; import io.vavr.control.Either; import io.vavr.control.Option; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.security.SignatureException; +import java.util.Objects; import org.biscuitsec.biscuit.crypto.BlockSignatureBuffer; import org.biscuitsec.biscuit.crypto.PublicKey; import org.biscuitsec.biscuit.crypto.Signer; @@ -11,79 +17,80 @@ import org.biscuitsec.biscuit.error.Error; import org.biscuitsec.biscuit.token.builder.Block; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.security.*; -import java.util.Objects; +public final class ThirdPartyBlockRequest { + private final PublicKey previousKey; -public class ThirdPartyBlockRequest { - PublicKey previousKey; + ThirdPartyBlockRequest(PublicKey previousKey) { + this.previousKey = previousKey; + } - ThirdPartyBlockRequest(PublicKey previousKey) { - this.previousKey = previousKey; + public Either createBlock( + final Signer externalSigner, Block blockBuilder) + throws NoSuchAlgorithmException, InvalidKeyException, SignatureException { + SymbolTable symbolTable = new SymbolTable(); + org.biscuitsec.biscuit.token.Block block = + blockBuilder.build(symbolTable, Option.some(externalSigner.getPublicKey())); + + Either res = block.toBytes(); + if (res.isLeft()) { + return Either.left(res.getLeft()); } - public Either createBlock(final Signer externalSigner, Block blockBuilder) throws NoSuchAlgorithmException, InvalidKeyException, SignatureException { - SymbolTable symbols = new SymbolTable(); - org.biscuitsec.biscuit.token.Block block = blockBuilder.build(symbols, Option.some(externalSigner.public_key())); + byte[] serializedBlock = res.get(); + byte[] payload = BlockSignatureBuffer.getBufferSignature(this.previousKey, serializedBlock); + byte[] signature = externalSigner.sign(payload); - Either res = block.to_bytes(); - if (res.isLeft()) { - return Either.left(res.getLeft()); - } + PublicKey publicKey = externalSigner.getPublicKey(); - byte[] serializedBlock = res.get(); - byte[] payload = BlockSignatureBuffer.getBufferSignature(this.previousKey, serializedBlock); - byte[] signature = externalSigner.sign(payload); + return Either.right(new ThirdPartyBlockContents(serializedBlock, signature, publicKey)); + } - PublicKey publicKey = externalSigner.public_key(); + public Schema.ThirdPartyBlockRequest serialize() throws Error.FormatError.SerializationError { + Schema.ThirdPartyBlockRequest.Builder b = Schema.ThirdPartyBlockRequest.newBuilder(); + b.setPreviousKey(this.previousKey.serialize()); - return Either.right(new ThirdPartyBlockContents(serializedBlock, signature, publicKey)); - } + return b.build(); + } - public Schema.ThirdPartyBlockRequest serialize() throws Error.FormatError.SerializationError { - Schema.ThirdPartyBlockRequest.Builder b = Schema.ThirdPartyBlockRequest.newBuilder(); - b.setPreviousKey(this.previousKey.serialize()); + public static ThirdPartyBlockRequest deserialize(Schema.ThirdPartyBlockRequest b) + throws Error.FormatError.DeserializationError { + PublicKey previousKey = PublicKey.deserialize(b.getPreviousKey()); + return new ThirdPartyBlockRequest(previousKey); + } - return b.build(); - } + public static ThirdPartyBlockRequest fromBytes(byte[] slice) + throws InvalidProtocolBufferException, Error.FormatError.DeserializationError { + return ThirdPartyBlockRequest.deserialize(Schema.ThirdPartyBlockRequest.parseFrom(slice)); + } - static public ThirdPartyBlockRequest deserialize(Schema.ThirdPartyBlockRequest b) throws Error.FormatError.DeserializationError { - PublicKey previousKey = PublicKey.deserialize(b.getPreviousKey()); - return new ThirdPartyBlockRequest(previousKey); - } + public byte[] toBytes() throws IOException, Error.FormatError.SerializationError { + Schema.ThirdPartyBlockRequest b = this.serialize(); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + b.writeTo(stream); + return stream.toByteArray(); + } - static public ThirdPartyBlockRequest fromBytes(byte[] slice) throws InvalidProtocolBufferException, Error.FormatError.DeserializationError { - return ThirdPartyBlockRequest.deserialize(Schema.ThirdPartyBlockRequest.parseFrom(slice)); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - public byte[] toBytes() throws IOException, Error.FormatError.SerializationError { - Schema.ThirdPartyBlockRequest b = this.serialize(); - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - b.writeTo(stream); - return stream.toByteArray(); + if (o == null || getClass() != o.getClass()) { + return false; } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - ThirdPartyBlockRequest that = (ThirdPartyBlockRequest) o; + ThirdPartyBlockRequest that = (ThirdPartyBlockRequest) o; - return Objects.equals(previousKey, that.previousKey); - } + return Objects.equals(previousKey, that.previousKey); + } - @Override - public int hashCode() { - return previousKey != null ? previousKey.hashCode() : 0; - } + @Override + public int hashCode() { + return previousKey != null ? previousKey.hashCode() : 0; + } - @Override - public String toString() { - return "ThirdPartyBlockRequest{" + - "previousKey=" + previousKey + - '}'; - } + @Override + public String toString() { + return "ThirdPartyBlockRequest{previousKey=" + previousKey + '}'; + } } - diff --git a/src/main/java/org/biscuitsec/biscuit/token/UnverifiedBiscuit.java b/src/main/java/org/biscuitsec/biscuit/token/UnverifiedBiscuit.java index bff22e35..dc49be95 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/UnverifiedBiscuit.java +++ b/src/main/java/org/biscuitsec/biscuit/token/UnverifiedBiscuit.java @@ -1,336 +1,364 @@ package org.biscuitsec.biscuit.token; import biscuit.format.schema.Schema.PublicKey.Algorithm; +import io.vavr.Tuple2; +import io.vavr.control.Either; +import io.vavr.control.Option; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.security.SignatureException; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.stream.Collectors; import org.biscuitsec.biscuit.crypto.BlockSignatureBuffer; import org.biscuitsec.biscuit.crypto.KeyDelegate; import org.biscuitsec.biscuit.crypto.KeyPair; import org.biscuitsec.biscuit.crypto.PublicKey; +import org.biscuitsec.biscuit.datalog.Check; +import org.biscuitsec.biscuit.datalog.SymbolTable; import org.biscuitsec.biscuit.error.Error; import org.biscuitsec.biscuit.token.format.ExternalSignature; import org.biscuitsec.biscuit.token.format.SerializedBiscuit; -import io.vavr.Tuple2; -import io.vavr.control.Either; -import io.vavr.control.Option; -import org.biscuitsec.biscuit.datalog.Check; -import org.biscuitsec.biscuit.datalog.SymbolTable; - -import java.security.*; -import java.util.*; -import java.util.stream.Collectors; /** - * UnverifiedBiscuit auth token. UnverifiedBiscuit means it's deserialized without checking signatures. + * UnverifiedBiscuit auth token. UnverifiedBiscuit means it's deserialized without checking + * signatures. */ public class UnverifiedBiscuit { - final Block authority; - final List blocks; - final SymbolTable symbols; - final SerializedBiscuit serializedBiscuit; - final List revocation_ids; - - UnverifiedBiscuit(Block authority, List blocks, SymbolTable symbols, SerializedBiscuit serializedBiscuit, - List revocation_ids) { - this.authority = authority; - this.blocks = blocks; - this.symbols = symbols; - this.serializedBiscuit = serializedBiscuit; - this.revocation_ids = revocation_ids; + protected final Block authority; + protected final List blocks; + protected final SymbolTable symbolTable; + protected final SerializedBiscuit serializedBiscuit; + protected final List revocationIds; + + UnverifiedBiscuit( + Block authority, + List blocks, + SymbolTable symbolTable, + SerializedBiscuit serializedBiscuit, + List revocationIds) { + this.authority = authority; + this.blocks = blocks; + this.symbolTable = symbolTable; + this.serializedBiscuit = serializedBiscuit; + this.revocationIds = revocationIds; + } + + /** + * Deserializes a Biscuit token from a base64 url (RFC4648_URLSAFE) string + * + *

This method uses the default symbol table + * + * @param data + * @return Biscuit + */ + public static UnverifiedBiscuit fromBase64Url(String data) throws Error { + return UnverifiedBiscuit.fromBytes(Base64.getUrlDecoder().decode(data)); + } + + /** + * Deserializes a Biscuit token from a byte array + * + *

This method uses the default symbol table + * + * @param data + * @return + */ + public static UnverifiedBiscuit fromBytes(byte[] data) throws Error { + return UnverifiedBiscuit.fromBytesWithSymbols(data, defaultSymbolTable()); + } + + /** + * Deserializes a UnverifiedBiscuit from a byte array + * + * @param data + * @return UnverifiedBiscuit + */ + public static UnverifiedBiscuit fromBytesWithSymbols(byte[] data, SymbolTable symbolTable) + throws Error { + SerializedBiscuit ser = SerializedBiscuit.deserializeUnsafe(data); + return UnverifiedBiscuit.fromSerializedBiscuit(ser, symbolTable); + } + + /** + * Fills a UnverifiedBiscuit structure from a deserialized token + * + * @return UnverifiedBiscuit + */ + private static UnverifiedBiscuit fromSerializedBiscuit(SerializedBiscuit ser, SymbolTable symbolTable) + throws Error { + Tuple2> t = ser.extractBlocks(symbolTable); + Block authority = t._1; + ArrayList blocks = t._2; + + List revocationIds = ser.revocationIdentifiers(); + + return new UnverifiedBiscuit(authority, blocks, symbolTable, ser, revocationIds); + } + + /** + * Serializes a token to a byte array + * + * @return + */ + public byte[] serialize() throws Error.FormatError.SerializationError { + return this.serializedBiscuit.serialize(); + } + + /** + * Serializes a token to base 64 url String using RFC4648_URLSAFE + * + * @return String + * @throws Error.FormatError.SerializationError + */ + public String serializeBase64Url() throws Error.FormatError.SerializationError { + return Base64.getUrlEncoder().encodeToString(serialize()); + } + + /** + * Creates a Block builder + * + * @return + */ + public org.biscuitsec.biscuit.token.builder.Block createBlock() { + return new org.biscuitsec.biscuit.token.builder.Block(); + } + + /** + * Generates a new token from an existing one and a new block + * + * @param block new block (should be generated from a Block builder) + * @param algorithm algorithm to use for the ephemeral key pair + * @return + */ + public UnverifiedBiscuit attenuate( + org.biscuitsec.biscuit.token.builder.Block block, Algorithm algorithm) throws Error { + SecureRandom rng = new SecureRandom(); + KeyPair keypair = KeyPair.generate(algorithm, rng); + SymbolTable builderSymbols = new SymbolTable(this.symbolTable); + return attenuate(rng, keypair, block.build(builderSymbols)); + } + + public UnverifiedBiscuit attenuate( + final SecureRandom rng, + final KeyPair keypair, + org.biscuitsec.biscuit.token.builder.Block block) + throws Error { + SymbolTable builderSymbols = new SymbolTable(this.symbolTable); + return attenuate(rng, keypair, block.build(builderSymbols)); + } + + /** + * Generates a new token from an existing one and a new block + * + * @param rng random number generator + * @param keypair ephemeral key pair + * @param block new block (should be generated from a Block builder) + * @return + */ + public UnverifiedBiscuit attenuate(final SecureRandom rng, final KeyPair keypair, Block block) + throws Error { + UnverifiedBiscuit copiedBiscuit = this.copy(); + + if (!copiedBiscuit.symbolTable.disjoint(block.getSymbolTable())) { + throw new Error.SymbolTableOverlap(); } - /** - * Deserializes a Biscuit token from a base64 url (RFC4648_URLSAFE) string - *

- * This method uses the default symbol table - * - * @param data - * @return Biscuit - */ - static public UnverifiedBiscuit from_b64url(String data) throws Error { - return UnverifiedBiscuit.from_bytes(Base64.getUrlDecoder().decode(data)); + Either containerRes = + copiedBiscuit.serializedBiscuit.append(keypair, block, Option.none()); + if (containerRes.isLeft()) { + throw containerRes.getLeft(); } - /** - * Deserializes a Biscuit token from a byte array - *

- * This method uses the default symbol table - * - * @param data - * @return - */ - static public UnverifiedBiscuit from_bytes(byte[] data) throws Error { - return UnverifiedBiscuit.from_bytes_with_symbols(data, default_symbol_table()); + SymbolTable symbols = new SymbolTable(copiedBiscuit.symbolTable); + for (String s : block.getSymbolTable().symbols()) { + symbols.add(s); } - /** - * Deserializes a UnverifiedBiscuit from a byte array - * - * @param data - * @return UnverifiedBiscuit - */ - static public UnverifiedBiscuit from_bytes_with_symbols(byte[] data, SymbolTable symbols) throws Error { - SerializedBiscuit ser = SerializedBiscuit.unsafe_deserialize(data); - return UnverifiedBiscuit.from_serialized_biscuit(ser, symbols); + ArrayList blocks = new ArrayList<>(); + for (Block b : copiedBiscuit.blocks) { + blocks.add(b); } + blocks.add(block); + SerializedBiscuit container = containerRes.get(); - /** - * Fills a UnverifiedBiscuit structure from a deserialized token - * - * @return UnverifiedBiscuit - */ - static private UnverifiedBiscuit from_serialized_biscuit(SerializedBiscuit ser, SymbolTable symbols) throws Error { - Tuple2> t = ser.extractBlocks(symbols); - Block authority = t._1; - ArrayList blocks = t._2; + List revocationIds = container.revocationIdentifiers(); - List revocation_ids = ser.revocation_identifiers(); + return new UnverifiedBiscuit( + copiedBiscuit.authority, blocks, symbols, container, revocationIds); + } - return new UnverifiedBiscuit(authority, blocks, symbols, ser, revocation_ids); - } - - /** - * Serializes a token to a byte array - * - * @return - */ - public byte[] serialize() throws Error.FormatError.SerializationError { - return this.serializedBiscuit.serialize(); - } + // FIXME: attenuate 3rd Party - /** - * Serializes a token to base 64 url String using RFC4648_URLSAFE - * - * @return String - * @throws Error.FormatError.SerializationError - */ - public String serialize_b64url() throws Error.FormatError.SerializationError { - return Base64.getUrlEncoder().encodeToString(serialize()); - } + public List revocationIdentifiers() { + return this.revocationIds.stream() + .map(RevocationIdentifier::fromBytes) + .collect(Collectors.toList()); + } - /** - * Creates a Block builder - * - * @return - */ - public org.biscuitsec.biscuit.token.builder.Block create_block() { - return new org.biscuitsec.biscuit.token.builder.Block(); - } + public List> getChecks() { + ArrayList> l = new ArrayList<>(); + l.add(new ArrayList<>(this.authority.getChecks())); - /** - * Generates a new token from an existing one and a new block - * - * @param block new block (should be generated from a Block builder) - * @param algorithm algorithm to use for the ephemeral key pair - * @return - */ - public UnverifiedBiscuit attenuate(org.biscuitsec.biscuit.token.builder.Block block, Algorithm algorithm) throws Error { - SecureRandom rng = new SecureRandom(); - KeyPair keypair = KeyPair.generate(algorithm, rng); - SymbolTable builderSymbols = new SymbolTable(this.symbols); - return attenuate(rng, keypair, block.build(builderSymbols)); + for (Block b : this.blocks) { + l.add(new ArrayList<>(b.getChecks())); } - public UnverifiedBiscuit attenuate(final SecureRandom rng, final KeyPair keypair, org.biscuitsec.biscuit.token.builder.Block block) throws Error { - SymbolTable builderSymbols = new SymbolTable(this.symbols); - return attenuate(rng, keypair, block.build(builderSymbols)); - } + return l; + } - /** - * Generates a new token from an existing one and a new block - * - * @param rng random number generator - * @param keypair ephemeral key pair - * @param block new block (should be generated from a Block builder) - * @return - */ - public UnverifiedBiscuit attenuate(final SecureRandom rng, final KeyPair keypair, Block block) throws Error { - UnverifiedBiscuit copiedBiscuit = this.copy(); - - if (!Collections.disjoint(copiedBiscuit.symbols.symbols, block.symbols.symbols)) { - throw new Error.SymbolTableOverlap(); - } - - Either containerRes = copiedBiscuit.serializedBiscuit.append(keypair, block, Option.none()); - if (containerRes.isLeft()) { - throw containerRes.getLeft(); - } - SerializedBiscuit container = containerRes.get(); - - SymbolTable symbols = new SymbolTable(copiedBiscuit.symbols); - for (String s : block.symbols.symbols) { - symbols.add(s); - } - - ArrayList blocks = new ArrayList<>(); - for (Block b : copiedBiscuit.blocks) { - blocks.add(b); - } - blocks.add(block); - - List revocation_ids = container.revocation_identifiers(); - - return new UnverifiedBiscuit(copiedBiscuit.authority, blocks, symbols, container, revocation_ids); + public List> getContext() { + ArrayList> res = new ArrayList<>(); + if (this.authority.getContext().isEmpty()) { + res.add(Option.none()); + } else { + res.add(Option.some(this.authority.getContext())); } - //FIXME: attenuate 3rd Party - public List revocation_identifiers() { - return this.revocation_ids.stream() - .map(RevocationIdentifier::from_bytes) - .collect(Collectors.toList()); + for (Block b : this.blocks) { + if (b.getContext().isEmpty()) { + res.add(Option.none()); + } else { + res.add(Option.some(b.getContext())); + } } - public List> checks() { - ArrayList> l = new ArrayList<>(); - l.add(new ArrayList<>(this.authority.checks)); - - for (Block b : this.blocks) { - l.add(new ArrayList<>(b.checks)); - } - - return l; + return res; + } + + public Option getRootKeyId() { + return this.serializedBiscuit.getRootKeyId(); + } + + /** Generates a third party block request from a token */ + public ThirdPartyBlockRequest thirdPartyRequest() { + PublicKey previousKey; + if (this.serializedBiscuit.getBlocks().isEmpty()) { + previousKey = this.serializedBiscuit.getAuthority().getKey(); + } else { + previousKey = + this.serializedBiscuit + .getBlocks() + .get(this.serializedBiscuit.getBlocks().size() - 1) + .getKey(); } - public List> context() { - ArrayList> res = new ArrayList<>(); - if (this.authority.context.isEmpty()) { - res.add(Option.none()); - } else { - res.add(Option.some(this.authority.context)); - } - - for (Block b : this.blocks) { - if (b.context.isEmpty()) { - res.add(Option.none()); - } else { - res.add(Option.some(b.context)); - } - } - - return res; + return new ThirdPartyBlockRequest(previousKey); + } + + /** Generates a third party block request from a token */ + public UnverifiedBiscuit appendThirdPartyBlock( + PublicKey externalKey, ThirdPartyBlockContents blockResponse) + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + PublicKey previousKey; + if (this.serializedBiscuit.getBlocks().isEmpty()) { + previousKey = this.serializedBiscuit.getAuthority().getKey(); + } else { + previousKey = + this.serializedBiscuit + .getBlocks() + .get(this.serializedBiscuit.getBlocks().size() - 1) + .getKey(); } - - public Option root_key_id() { - return this.serializedBiscuit.root_key_id; + KeyPair nextKeyPair = KeyPair.generate(previousKey.getAlgorithm()); + byte[] payload = + BlockSignatureBuffer.getBufferSignature(previousKey, blockResponse.getPayload()); + if (!KeyPair.verify(externalKey, payload, blockResponse.getSignature())) { + throw new Error.FormatError.Signature.InvalidSignature( + "signature error: Verification equation was not satisfied"); } - /** - * Generates a third party block request from a token - */ - public ThirdPartyBlockRequest thirdPartyRequest() { - PublicKey previousKey; - if(this.serializedBiscuit.blocks.isEmpty()) { - previousKey = this.serializedBiscuit.authority.key; - } else { - previousKey = this.serializedBiscuit.blocks.get(this.serializedBiscuit.blocks.size() - 1).key; - } - - return new ThirdPartyBlockRequest(previousKey); + Either res = + Block.fromBytes(blockResponse.getPayload(), Option.some(externalKey)); + if (res.isLeft()) { + throw res.getLeft(); } + Block block = res.get(); - /** - * Generates a third party block request from a token - */ - public UnverifiedBiscuit appendThirdPartyBlock(PublicKey externalKey, ThirdPartyBlockContents blockResponse) - throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - PublicKey previousKey; - if(this.serializedBiscuit.blocks.isEmpty()) { - previousKey = this.serializedBiscuit.authority.key; - } else { - previousKey = this.serializedBiscuit.blocks.get(this.serializedBiscuit.blocks.size() - 1).key; - } - KeyPair nextKeyPair = KeyPair.generate(previousKey.algorithm); - byte[] payload = BlockSignatureBuffer.getBufferSignature(previousKey, blockResponse.payload); - if (!KeyPair.verify(externalKey, payload, blockResponse.signature)) { - throw new Error.FormatError.Signature.InvalidSignature("signature error: Verification equation was not satisfied"); - } - - Either res = Block.from_bytes(blockResponse.payload, Option.some(externalKey)); - if(res.isLeft()) { - throw res.getLeft(); - } - - Block block = res.get(); - - ExternalSignature externalSignature = new ExternalSignature(externalKey, blockResponse.signature); - - UnverifiedBiscuit copiedBiscuit = this.copy(); + ExternalSignature externalSignature = + new ExternalSignature(externalKey, blockResponse.getSignature()); - Either containerRes = copiedBiscuit.serializedBiscuit.append(nextKeyPair, block, Option.some(externalSignature)); - if (containerRes.isLeft()) { - throw containerRes.getLeft(); - } + UnverifiedBiscuit copiedBiscuit = this.copy(); - SerializedBiscuit container = containerRes.get(); - - SymbolTable symbols = new SymbolTable(copiedBiscuit.symbols); - - ArrayList blocks = new ArrayList<>(); - for (Block b : copiedBiscuit.blocks) { - blocks.add(b); - } - blocks.add(block); - - List revocation_ids = container.revocation_identifiers(); - return new UnverifiedBiscuit(copiedBiscuit.authority, blocks, symbols, container, revocation_ids); + Either containerRes = + copiedBiscuit.serializedBiscuit.append(nextKeyPair, block, Option.some(externalSignature)); + if (containerRes.isLeft()) { + throw containerRes.getLeft(); } - /** - * Prints a token's content - */ - public String print() { - StringBuilder s = new StringBuilder(); - s.append("UnverifiedBiscuit {\n\tsymbols: "); - s.append(this.symbols.getAllSymbols()); - s.append("\n\tauthority: "); - s.append(this.authority.print(this.symbols)); - s.append("\n\tblocks: [\n"); - for (Block b : this.blocks) { - s.append("\t\t"); - s.append(b.print(this.symbols)); - s.append("\n"); - } - s.append("\t]\n}"); - - return s.toString(); - } + SerializedBiscuit container = containerRes.get(); - /** - * Default symbols list - */ - static public SymbolTable default_symbol_table() { - return new SymbolTable(); - } + SymbolTable symbols = new SymbolTable(copiedBiscuit.symbolTable); - @Override - protected Object clone() throws CloneNotSupportedException { - return super.clone(); + ArrayList blocks = new ArrayList<>(); + for (Block b : copiedBiscuit.blocks) { + blocks.add(b); } - - public UnverifiedBiscuit copy() throws Error { - return UnverifiedBiscuit.from_bytes(this.serialize()); + blocks.add(block); + + List revocationIds = container.revocationIdentifiers(); + return new UnverifiedBiscuit( + copiedBiscuit.authority, blocks, symbols, container, revocationIds); + } + + /** Prints a token's content */ + public String print() { + StringBuilder s = new StringBuilder(); + s.append("UnverifiedBiscuit {\n\tsymbols: "); + s.append(this.symbolTable.getAllSymbols()); + s.append("\n\tauthority: "); + s.append(this.authority.print(this.symbolTable)); + s.append("\n\tblocks: [\n"); + for (Block b : this.blocks) { + s.append("\t\t"); + s.append(b.print(this.symbolTable)); + s.append("\n"); } - - public Biscuit verify(PublicKey publicKey) throws Error, NoSuchAlgorithmException, SignatureException, InvalidKeyException { - SerializedBiscuit serializedBiscuit = this.serializedBiscuit; - var result = serializedBiscuit.verify(publicKey); - if (result.isLeft()) { - throw result.getLeft(); - } - return Biscuit.from_serialized_biscuit(serializedBiscuit, this.symbols); + s.append("\t]\n}"); + + return s.toString(); + } + + /** Default symbols list */ + public static SymbolTable defaultSymbolTable() { + return new SymbolTable(); + } + + @Override + protected Object clone() throws CloneNotSupportedException { + return super.clone(); + } + + public UnverifiedBiscuit copy() throws Error { + return UnverifiedBiscuit.fromBytes(this.serialize()); + } + + public Biscuit verify(PublicKey publicKey) + throws Error, NoSuchAlgorithmException, SignatureException, InvalidKeyException { + SerializedBiscuit serializedBiscuit = this.serializedBiscuit; + var result = serializedBiscuit.verify(publicKey); + if (result.isLeft()) { + throw result.getLeft(); } + return Biscuit.fromSerializedBiscuit(serializedBiscuit, this.symbolTable); + } - public Biscuit verify(KeyDelegate delegate) throws Error, NoSuchAlgorithmException, SignatureException, InvalidKeyException { - SerializedBiscuit serializedBiscuit = this.serializedBiscuit; + public Biscuit verify(KeyDelegate delegate) + throws Error, NoSuchAlgorithmException, SignatureException, InvalidKeyException { + SerializedBiscuit serializedBiscuit = this.serializedBiscuit; - Option root = delegate.root_key(serializedBiscuit.root_key_id); - if(root.isEmpty()) { - throw new InvalidKeyException("unknown root key id"); - } + Option root = delegate.getRootKey(serializedBiscuit.getRootKeyId()); + if (root.isEmpty()) { + throw new InvalidKeyException("unknown root key id"); + } - var result = serializedBiscuit.verify(root.get()); - if (result.isLeft()) { - throw result.getLeft(); - } - return Biscuit.from_serialized_biscuit(serializedBiscuit, this.symbols); + var result = serializedBiscuit.verify(root.get()); + if (result.isLeft()) { + throw result.getLeft(); } + return Biscuit.fromSerializedBiscuit(serializedBiscuit, this.symbolTable); + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/Biscuit.java b/src/main/java/org/biscuitsec/biscuit/token/builder/Biscuit.java index 1bcdc99e..1b5f897d 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/Biscuit.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/Biscuit.java @@ -1,182 +1,206 @@ package org.biscuitsec.biscuit.token.builder; -import org.biscuitsec.biscuit.crypto.PublicKey; -import org.biscuitsec.biscuit.datalog.SchemaVersion; -import org.biscuitsec.biscuit.datalog.SymbolTable; -import org.biscuitsec.biscuit.error.Error; -import org.biscuitsec.biscuit.token.Block; +import static org.biscuitsec.biscuit.token.UnverifiedBiscuit.defaultSymbolTable; + import io.vavr.Tuple2; import io.vavr.control.Either; import io.vavr.control.Option; -import org.biscuitsec.biscuit.token.builder.parser.Parser; - import java.security.SecureRandom; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.biscuitsec.biscuit.crypto.PublicKey; +import org.biscuitsec.biscuit.datalog.SchemaVersion; +import org.biscuitsec.biscuit.datalog.SymbolTable; +import org.biscuitsec.biscuit.error.Error; +import org.biscuitsec.biscuit.token.Block; +import org.biscuitsec.biscuit.token.builder.parser.Parser; -import static org.biscuitsec.biscuit.token.UnverifiedBiscuit.default_symbol_table; - -public class Biscuit { - SecureRandom rng; - org.biscuitsec.biscuit.crypto.Signer root; - String context; - List facts; - List rules; - List checks; - List scopes; - Option root_key_id; - - public Biscuit(final SecureRandom rng, final org.biscuitsec.biscuit.crypto.Signer root) { - this.rng = rng; - this.root = root; - this.context = ""; - this.facts = new ArrayList<>(); - this.rules = new ArrayList<>(); - this.checks = new ArrayList<>(); - this.scopes = new ArrayList<>(); - this.root_key_id = Option.none(); +public final class Biscuit { + private SecureRandom rng; + private org.biscuitsec.biscuit.crypto.Signer root; + private String context; + private List facts; + private List rules; + private List checks; + private List scopes; + private Option rootKeyId; + + public Biscuit(final SecureRandom rng, final org.biscuitsec.biscuit.crypto.Signer root) { + this.rng = rng; + this.root = root; + this.context = ""; + this.facts = new ArrayList<>(); + this.rules = new ArrayList<>(); + this.checks = new ArrayList<>(); + this.scopes = new ArrayList<>(); + this.rootKeyId = Option.none(); + } + + public Biscuit( + final SecureRandom rng, + final org.biscuitsec.biscuit.crypto.Signer root, + Option rootKeyId) { + this.rng = rng; + this.root = root; + this.context = ""; + this.facts = new ArrayList<>(); + this.rules = new ArrayList<>(); + this.checks = new ArrayList<>(); + this.scopes = new ArrayList<>(); + this.rootKeyId = rootKeyId; + } + + public Biscuit( + final SecureRandom rng, + final org.biscuitsec.biscuit.crypto.Signer root, + Option rootKeyId, + org.biscuitsec.biscuit.token.builder.Block block) { + this.rng = rng; + this.root = root; + this.rootKeyId = rootKeyId; + this.context = block.context(); + this.facts = block.facts(); + this.rules = block.rules(); + this.checks = block.checks(); + this.scopes = block.scopes(); + } + + public Biscuit addAuthorityFact(org.biscuitsec.biscuit.token.builder.Fact f) + throws Error.Language { + f.validate(); + this.facts.add(f); + return this; + } + + public Biscuit addAuthorityFact(String s) throws Error.Parser, Error.Language { + Either< + org.biscuitsec.biscuit.token.builder.parser.Error, + Tuple2> + res = Parser.fact(s); + + if (res.isLeft()) { + throw new Error.Parser(res.getLeft()); } - public Biscuit(final SecureRandom rng, final org.biscuitsec.biscuit.crypto.Signer root, Option root_key_id) { - this.rng = rng; - this.root = root; - this.context = ""; - this.facts = new ArrayList<>(); - this.rules = new ArrayList<>(); - this.checks = new ArrayList<>(); - this.scopes = new ArrayList<>(); - this.root_key_id = root_key_id; - } + Tuple2 t = res.get(); - public Biscuit(final SecureRandom rng, final org.biscuitsec.biscuit.crypto.Signer root, Option root_key_id, org.biscuitsec.biscuit.token.builder.Block block) { - this.rng = rng; - this.root = root; - this.root_key_id = root_key_id; - this.context = block.context; - this.facts = block.facts; - this.rules = block.rules; - this.checks = block.checks; - this.scopes = block.scopes; - } + return addAuthorityFact(t._2); + } + + public Biscuit addAuthorityRule(org.biscuitsec.biscuit.token.builder.Rule rule) { + this.rules.add(rule); + return this; + } + + public Biscuit addAuthorityRule(String s) throws Error.Parser { + Either< + org.biscuitsec.biscuit.token.builder.parser.Error, + Tuple2> + res = Parser.rule(s); - public Biscuit add_authority_fact(org.biscuitsec.biscuit.token.builder.Fact f) throws Error.Language { - f.validate(); - this.facts.add(f); - return this; + if (res.isLeft()) { + throw new Error.Parser(res.getLeft()); } - public Biscuit add_authority_fact(String s) throws Error.Parser, Error.Language { - Either> res = - Parser.fact(s); + Tuple2 t = res.get(); - if (res.isLeft()) { - throw new Error.Parser(res.getLeft()); - } + return addAuthorityRule(t._2); + } - Tuple2 t = res.get(); + public Biscuit addAuthorityCheck(org.biscuitsec.biscuit.token.builder.Check c) { + this.checks.add(c); + return this; + } - return add_authority_fact(t._2); - } + public Biscuit addAuthorityCheck(String s) throws Error.Parser { + Either< + org.biscuitsec.biscuit.token.builder.parser.Error, + Tuple2> + res = Parser.check(s); - public Biscuit add_authority_rule(org.biscuitsec.biscuit.token.builder.Rule rule) { - this.rules.add(rule); - return this; + if (res.isLeft()) { + throw new Error.Parser(res.getLeft()); } - public Biscuit add_authority_rule(String s) throws Error.Parser { - Either> res = - Parser.rule(s); + Tuple2 t = res.get(); - if (res.isLeft()) { - throw new Error.Parser(res.getLeft()); - } + return addAuthorityCheck(t._2); + } - Tuple2 t = res.get(); + public Biscuit setContext(String context) { + this.context = context; + return this; + } - return add_authority_rule(t._2); - } + public Biscuit addScope(org.biscuitsec.biscuit.token.builder.Scope scope) { + this.scopes.add(scope); + return this; + } - public Biscuit add_authority_check(org.biscuitsec.biscuit.token.builder.Check c) { - this.checks.add(c); - return this; - } + public void setRootKeyId(Integer i) { + this.rootKeyId = Option.some(i); + } - public Biscuit add_authority_check(String s) throws Error.Parser { - Either> res = - Parser.check(s); + public org.biscuitsec.biscuit.token.Biscuit build() throws Error { + return build(defaultSymbolTable()); + } - if (res.isLeft()) { - throw new Error.Parser(res.getLeft()); - } + private org.biscuitsec.biscuit.token.Biscuit build(SymbolTable symbolTable) throws Error { + final int symbolStart = symbolTable.currentOffset(); + final int publicKeyStart = symbolTable.currentPublicKeyOffset(); - Tuple2 t = res.get(); - - return add_authority_check(t._2); + List facts = new ArrayList<>(); + for (Fact f : this.facts) { + facts.add(f.convert(symbolTable)); } - - public Biscuit set_context(String context) { - this.context = context; - return this; + List rules = new ArrayList<>(); + for (Rule r : this.rules) { + rules.add(r.convert(symbolTable)); } - - public Biscuit add_scope(org.biscuitsec.biscuit.token.builder.Scope scope) { - this.scopes.add(scope); - return this; + List checks = new ArrayList<>(); + for (Check c : this.checks) { + checks.add(c.convert(symbolTable)); } - - public void set_root_key_id(Integer i) { - this.root_key_id = Option.some(i); + List scopes = new ArrayList<>(); + for (Scope s : this.scopes) { + scopes.add(s.convert(symbolTable)); } + SchemaVersion schemaVersion = new SchemaVersion(facts, rules, checks, scopes); + + SymbolTable blockSymbols = new SymbolTable(); - public org.biscuitsec.biscuit.token.Biscuit build() throws Error { - return build(default_symbol_table()); + for (int i = symbolStart; i < symbolTable.symbols().size(); i++) { + blockSymbols.add(symbolTable.symbols().get(i)); } - private org.biscuitsec.biscuit.token.Biscuit build(SymbolTable symbols) throws Error { - int symbol_start = symbols.currentOffset(); - int publicKeyStart = symbols.currentPublicKeyOffset(); - - List facts = new ArrayList<>(); - for(Fact f: this.facts) { - facts.add(f.convert(symbols)); - } - List rules = new ArrayList<>(); - for(Rule r: this.rules) { - rules.add(r.convert(symbols)); - } - List checks = new ArrayList<>(); - for(Check c: this.checks) { - checks.add(c.convert(symbols)); - } - List scopes = new ArrayList<>(); - for(Scope s: this.scopes) { - scopes.add(s.convert(symbols)); - } - SchemaVersion schemaVersion = new SchemaVersion(facts, rules, checks, scopes); - - SymbolTable block_symbols = new SymbolTable(); - - for (int i = symbol_start; i < symbols.symbols.size(); i++) { - block_symbols.add(symbols.symbols.get(i)); - } - - List publicKeys = new ArrayList<>(); - for (int i = publicKeyStart; i < symbols.currentPublicKeyOffset(); i++) { - publicKeys.add(symbols.publicKeys().get(i)); - } - - Block authority_block = new Block(block_symbols, context, facts, rules, - checks, scopes, publicKeys, Option.none(), schemaVersion.version()); - - if (this.root_key_id.isDefined()) { - return org.biscuitsec.biscuit.token.Biscuit.make(this.rng, this.root, this.root_key_id.get(), authority_block); - } else { - return org.biscuitsec.biscuit.token.Biscuit.make(this.rng, this.root, authority_block); - } + List publicKeys = new ArrayList<>(); + for (int i = publicKeyStart; i < symbolTable.currentPublicKeyOffset(); i++) { + publicKeys.add(symbolTable.getPublicKeys().get(i)); } - public Biscuit add_right(String resource, String right) throws Error.Language { - return this.add_authority_fact(Utils.fact("right", Arrays.asList(Utils.string(resource), Utils.s(right)))); + Block authorityBlock = + new Block( + blockSymbols, + context, + facts, + rules, + checks, + scopes, + publicKeys, + Option.none(), + schemaVersion.version()); + + if (this.rootKeyId.isDefined()) { + return org.biscuitsec.biscuit.token.Biscuit.make( + this.rng, this.root, this.rootKeyId.get(), authorityBlock); + } else { + return org.biscuitsec.biscuit.token.Biscuit.make(this.rng, this.root, authorityBlock); } + } + + public Biscuit addRight(String resource, String right) throws Error.Language { + return this.addAuthorityFact( + Utils.fact("right", Arrays.asList(Utils.string(resource), Utils.str(right)))); + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/Block.java b/src/main/java/org/biscuitsec/biscuit/token/builder/Block.java index 344f51c9..1e3366ec 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/Block.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/Block.java @@ -1,227 +1,290 @@ package org.biscuitsec.biscuit.token.builder; +import static org.biscuitsec.biscuit.datalog.Check.Kind.ONE; +import static org.biscuitsec.biscuit.token.UnverifiedBiscuit.defaultSymbolTable; +import static org.biscuitsec.biscuit.token.builder.Utils.constrainedRule; +import static org.biscuitsec.biscuit.token.builder.Utils.date; +import static org.biscuitsec.biscuit.token.builder.Utils.pred; +import static org.biscuitsec.biscuit.token.builder.Utils.rule; +import static org.biscuitsec.biscuit.token.builder.Utils.str; +import static org.biscuitsec.biscuit.token.builder.Utils.string; +import static org.biscuitsec.biscuit.token.builder.Utils.var; -import org.biscuitsec.biscuit.crypto.PublicKey; -import org.biscuitsec.biscuit.datalog.SymbolTable; -import org.biscuitsec.biscuit.error.Error; -import org.biscuitsec.biscuit.datalog.SchemaVersion; import io.vavr.Tuple2; import io.vavr.control.Either; import io.vavr.control.Option; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Date; +import java.util.List; +import java.util.Objects; +import org.biscuitsec.biscuit.crypto.PublicKey; +import org.biscuitsec.biscuit.datalog.SchemaVersion; +import org.biscuitsec.biscuit.datalog.SymbolTable; +import org.biscuitsec.biscuit.error.Error; import org.biscuitsec.biscuit.token.builder.parser.Parser; -import static org.biscuitsec.biscuit.datalog.Check.Kind.One; -import static org.biscuitsec.biscuit.token.UnverifiedBiscuit.default_symbol_table; -import static org.biscuitsec.biscuit.token.builder.Utils.*; +public final class Block { + private String context; + private List facts; + private List rules; + private List checks; + private List scopes; + + public Block() { + this.context = ""; + this.facts = new ArrayList<>(); + this.rules = new ArrayList<>(); + this.checks = new ArrayList<>(); + this.scopes = new ArrayList<>(); + } + + public Block addFact(org.biscuitsec.biscuit.token.builder.Fact f) { + this.facts.add(f); + return this; + } + + public Block addFact(String s) throws Error.Parser { + Either< + org.biscuitsec.biscuit.token.builder.parser.Error, + Tuple2> + res = Parser.fact(s); + + if (res.isLeft()) { + throw new Error.Parser(res.getLeft()); + } + Tuple2 t = res.get(); -import java.util.*; + return addFact(t._2); + } -public class Block { - String context; - List facts; - List rules; - List checks; - List scopes; + public Block addRule(org.biscuitsec.biscuit.token.builder.Rule rule) { + this.rules.add(rule); + return this; + } - public Block() { - this.context = ""; - this.facts = new ArrayList<>(); - this.rules = new ArrayList<>(); - this.checks = new ArrayList<>(); - this.scopes = new ArrayList<>(); - } + public Block addRule(String s) throws Error.Parser { + Either< + org.biscuitsec.biscuit.token.builder.parser.Error, + Tuple2> + res = Parser.rule(s); - public Block add_fact(org.biscuitsec.biscuit.token.builder.Fact f) { - this.facts.add(f); - return this; + if (res.isLeft()) { + throw new Error.Parser(res.getLeft()); } - public Block add_fact(String s) throws Error.Parser { - Either> res = - Parser.fact(s); + Tuple2 t = res.get(); - if (res.isLeft()) { - throw new Error.Parser(res.getLeft()); - } + return addRule(t._2); + } - Tuple2 t = res.get(); + public Block addCheck(org.biscuitsec.biscuit.token.builder.Check check) { + this.checks.add(check); + return this; + } - return add_fact(t._2); - } + public Block addCheck(String s) throws Error.Parser { + Either< + org.biscuitsec.biscuit.token.builder.parser.Error, + Tuple2> + res = Parser.check(s); - public Block add_rule(org.biscuitsec.biscuit.token.builder.Rule rule) { - this.rules.add(rule); - return this; + if (res.isLeft()) { + throw new Error.Parser(res.getLeft()); } - public Block add_rule(String s) throws Error.Parser { - Either> res = - Parser.rule(s); + Tuple2 t = res.get(); - if (res.isLeft()) { - throw new Error.Parser(res.getLeft()); - } - - Tuple2 t = res.get(); - - return add_rule(t._2); - } + return addCheck(t._2); + } - public Block add_check(org.biscuitsec.biscuit.token.builder.Check check) { - this.checks.add(check); - return this; - } + public Block addScope(org.biscuitsec.biscuit.token.builder.Scope scope) { + this.scopes.add(scope); + return this; + } - public Block add_check(String s) throws Error.Parser { - Either> res = - Parser.check(s); + public Block setContext(String context) { + this.context = context; + return this; + } - if (res.isLeft()) { - throw new Error.Parser(res.getLeft()); - } + public org.biscuitsec.biscuit.token.Block build() { + return build(defaultSymbolTable(), Option.none()); + } - Tuple2 t = res.get(); + public org.biscuitsec.biscuit.token.Block build(final Option externalKey) { + return build(defaultSymbolTable(), externalKey); + } - return add_check(t._2); - } + public org.biscuitsec.biscuit.token.Block build(SymbolTable symbolTable) { + return build(symbolTable, Option.none()); + } - public Block add_scope(org.biscuitsec.biscuit.token.builder.Scope scope) { - this.scopes.add(scope); - return this; + public org.biscuitsec.biscuit.token.Block build( + SymbolTable symbolTable, final Option externalKey) { + if (externalKey.isDefined()) { + symbolTable = new SymbolTable(); } + final int symbolStart = symbolTable.currentOffset(); + final int publicKeyStart = symbolTable.currentPublicKeyOffset(); - public Block set_context(String context) { - this.context = context; - return this; + List facts = new ArrayList<>(); + for (Fact f : this.facts) { + facts.add(f.convert(symbolTable)); } - - public org.biscuitsec.biscuit.token.Block build() { - return build(default_symbol_table(), Option.none()); + List rules = new ArrayList<>(); + for (Rule r : this.rules) { + rules.add(r.convert(symbolTable)); } - - public org.biscuitsec.biscuit.token.Block build(final Option externalKey) { - return build(default_symbol_table(), externalKey); + List checks = new ArrayList<>(); + for (Check c : this.checks) { + checks.add(c.convert(symbolTable)); } - - public org.biscuitsec.biscuit.token.Block build(SymbolTable symbols) { - return build(symbols, Option.none()); + List scopes = new ArrayList<>(); + for (Scope s : this.scopes) { + scopes.add(s.convert(symbolTable)); } + SchemaVersion schemaVersion = new SchemaVersion(facts, rules, checks, scopes); - public org.biscuitsec.biscuit.token.Block build(SymbolTable symbols, final Option externalKey) { - if(externalKey.isDefined()) { - symbols = new SymbolTable(); - } - int symbol_start = symbols.currentOffset(); - int publicKeyStart = symbols.currentPublicKeyOffset(); - - List facts = new ArrayList<>(); - for(Fact f: this.facts) { - facts.add(f.convert(symbols)); - } - List rules = new ArrayList<>(); - for(Rule r: this.rules) { - rules.add(r.convert(symbols)); - } - List checks = new ArrayList<>(); - for(Check c: this.checks) { - checks.add(c.convert(symbols)); - } - List scopes = new ArrayList<>(); - for(Scope s: this.scopes) { - scopes.add(s.convert(symbols)); - } - SchemaVersion schemaVersion = new SchemaVersion(facts, rules, checks, scopes); - - SymbolTable block_symbols = new SymbolTable(); + SymbolTable blockSymbols = new SymbolTable(); - for (int i = symbol_start; i < symbols.symbols.size(); i++) { - block_symbols.add(symbols.symbols.get(i)); - } - - List publicKeys = new ArrayList<>(); - for (int i = publicKeyStart; i < symbols.currentPublicKeyOffset(); i++) { - publicKeys.add(symbols.publicKeys().get(i)); - } + for (int i = symbolStart; i < symbolTable.symbols().size(); i++) { + blockSymbols.add(symbolTable.symbols().get(i)); + } - return new org.biscuitsec.biscuit.token.Block(block_symbols, this.context, facts, rules, checks, - scopes, publicKeys, externalKey, schemaVersion.version()); + List publicKeys = new ArrayList<>(); + for (int i = publicKeyStart; i < symbolTable.currentPublicKeyOffset(); i++) { + publicKeys.add(symbolTable.getPublicKeys().get(i)); } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + return new org.biscuitsec.biscuit.token.Block( + blockSymbols, + this.context, + facts, + rules, + checks, + scopes, + publicKeys, + externalKey, + schemaVersion.version()); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } - Block block = (Block) o; + Block block = (Block) o; - if (!Objects.equals(context, block.context)) return false; - if (!Objects.equals(facts, block.facts)) return false; - if (!Objects.equals(rules, block.rules)) return false; - if (!Objects.equals(checks, block.checks)) return false; - return Objects.equals(scopes, block.scopes); + if (!Objects.equals(context, block.context)) { + return false; } - - @Override - public int hashCode() { - int result = context != null ? context.hashCode() : 0; - result = 31 * result + (facts != null ? facts.hashCode() : 0); - result = 31 * result + (rules != null ? rules.hashCode() : 0); - result = 31 * result + (checks != null ? checks.hashCode() : 0); - result = 31 * result + (scopes != null ? scopes.hashCode() : 0); - return result; + if (!Objects.equals(facts, block.facts)) { + return false; } - - public Block check_right(String right) { - ArrayList queries = new ArrayList<>(); - queries.add(rule( - "check_right", - Arrays.asList(s(right)), - Arrays.asList( - pred("resource", Arrays.asList(var("resource"))), - pred("operation", Arrays.asList(s(right))), - pred("right", Arrays.asList(var("resource"), s(right))) - ) - )); - return this.add_check(new org.biscuitsec.biscuit.token.builder.Check(One, queries)); + if (!Objects.equals(rules, block.rules)) { + return false; } - - public Block resource_prefix(String prefix) { - ArrayList queries = new ArrayList<>(); - - queries.add(constrained_rule( - "prefix", - Arrays.asList(var("resource")), - Arrays.asList(pred("resource", Arrays.asList(var("resource")))), - Arrays.asList(new Expression.Binary(Expression.Op.Prefix, new Expression.Value(var("resource")), - new Expression.Value(string(prefix)))) - )); - return this.add_check(new org.biscuitsec.biscuit.token.builder.Check(One, queries)); - } - - public Block resource_suffix(String suffix) { - ArrayList queries = new ArrayList<>(); - - queries.add(constrained_rule( - "suffix", - Arrays.asList(var("resource")), - Arrays.asList(pred("resource", Arrays.asList(var("resource")))), - Arrays.asList(new Expression.Binary(Expression.Op.Suffix, new Expression.Value(var("resource")), - new Expression.Value(string(suffix)))) - )); - return this.add_check(new org.biscuitsec.biscuit.token.builder.Check(One, queries)); - } - - public Block expiration_date(Date d) { - ArrayList queries = new ArrayList<>(); - - queries.add(constrained_rule( - "expiration", - Arrays.asList(var("date")), - Arrays.asList(pred("time", Arrays.asList(var("date")))), - Arrays.asList(new Expression.Binary(Expression.Op.LessOrEqual, new Expression.Value(var("date")), - new Expression.Value(date(d)))) - )); - return this.add_check(new org.biscuitsec.biscuit.token.builder.Check(One, queries)); + if (!Objects.equals(checks, block.checks)) { + return false; } + return Objects.equals(scopes, block.scopes); + } + + @Override + public int hashCode() { + int result = context != null ? context.hashCode() : 0; + result = 31 * result + (facts != null ? facts.hashCode() : 0); + result = 31 * result + (rules != null ? rules.hashCode() : 0); + result = 31 * result + (checks != null ? checks.hashCode() : 0); + result = 31 * result + (scopes != null ? scopes.hashCode() : 0); + return result; + } + + public Block checkRight(String right) { + ArrayList queries = new ArrayList<>(); + queries.add( + rule( + "check_right", + Arrays.asList(str(right)), + Arrays.asList( + pred("resource", Arrays.asList(var("resource"))), + pred("operation", Arrays.asList(str(right))), + pred("right", Arrays.asList(var("resource"), str(right)))))); + return this.addCheck(new org.biscuitsec.biscuit.token.builder.Check(ONE, queries)); + } + + public Block resourcePrefix(String prefix) { + ArrayList queries = new ArrayList<>(); + + queries.add( + constrainedRule( + "prefix", + Arrays.asList(var("resource")), + Arrays.asList(pred("resource", Arrays.asList(var("resource")))), + Arrays.asList( + new Expression.Binary( + Expression.Op.Prefix, + new Expression.Value(var("resource")), + new Expression.Value(string(prefix)))))); + return this.addCheck(new org.biscuitsec.biscuit.token.builder.Check(ONE, queries)); + } + + public Block resourceSuffix(String suffix) { + ArrayList queries = new ArrayList<>(); + + queries.add( + constrainedRule( + "suffix", + Arrays.asList(var("resource")), + Arrays.asList(pred("resource", Arrays.asList(var("resource")))), + Arrays.asList( + new Expression.Binary( + Expression.Op.Suffix, + new Expression.Value(var("resource")), + new Expression.Value(string(suffix)))))); + return this.addCheck(new org.biscuitsec.biscuit.token.builder.Check(ONE, queries)); + } + + public Block setExpirationDate(Date d) { + ArrayList queries = new ArrayList<>(); + + queries.add( + constrainedRule( + "expiration", + Arrays.asList(var("date")), + Arrays.asList(pred("time", Arrays.asList(var("date")))), + Arrays.asList( + new Expression.Binary( + Expression.Op.LessOrEqual, + new Expression.Value(var("date")), + new Expression.Value(date(d)))))); + return this.addCheck(new org.biscuitsec.biscuit.token.builder.Check(ONE, queries)); + } + + public String context() { + return context; + } + + public List facts() { + return Collections.unmodifiableList(facts); + } + + public List rules() { + return Collections.unmodifiableList(rules); + } + + public List checks() { + return Collections.unmodifiableList(checks); + } + + public List scopes() { + return scopes; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/Check.java b/src/main/java/org/biscuitsec/biscuit/token/builder/Check.java index 40db3d66..c6ac12b7 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/Check.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/Check.java @@ -1,72 +1,76 @@ package org.biscuitsec.biscuit.token.builder; -import org.biscuitsec.biscuit.datalog.SymbolTable; +import static org.biscuitsec.biscuit.datalog.Check.Kind.ONE; import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; +import org.biscuitsec.biscuit.datalog.SymbolTable; -import static org.biscuitsec.biscuit.datalog.Check.Kind.One; +public final class Check { + private final org.biscuitsec.biscuit.datalog.Check.Kind kind; + private final List queries; -public class Check { - private final org.biscuitsec.biscuit.datalog.Check.Kind kind; - List queries; + public Check(org.biscuitsec.biscuit.datalog.Check.Kind kind, List queries) { + this.kind = kind; + this.queries = queries; + } - public Check(org.biscuitsec.biscuit.datalog.Check.Kind kind, List queries) { - this.kind = kind; - this.queries = queries; - } + public Check(org.biscuitsec.biscuit.datalog.Check.Kind kind, Rule query) { + this.kind = kind; - public Check(org.biscuitsec.biscuit.datalog.Check.Kind kind, Rule query) { - this.kind = kind; + ArrayList r = new ArrayList<>(); + r.add(query); + queries = r; + } - ArrayList r = new ArrayList<>(); - r.add(query); - queries = r; + public org.biscuitsec.biscuit.datalog.Check convert(SymbolTable symbolTable) { + ArrayList queries = new ArrayList<>(); + + for (Rule q : this.queries) { + queries.add(q.convert(symbolTable)); } + return new org.biscuitsec.biscuit.datalog.Check(this.kind, queries); + } - public org.biscuitsec.biscuit.datalog.Check convert(SymbolTable symbols) { - ArrayList queries = new ArrayList<>(); + public static Check convertFrom(org.biscuitsec.biscuit.datalog.Check r, SymbolTable symbolTable) { + ArrayList queries = new ArrayList<>(); - for(Rule q: this.queries) { - queries.add(q.convert(symbols)); - } - return new org.biscuitsec.biscuit.datalog.Check(this.kind, queries); + for (org.biscuitsec.biscuit.datalog.Rule q : r.queries()) { + queries.add(Rule.convertFrom(q, symbolTable)); } - public static Check convert_from(org.biscuitsec.biscuit.datalog.Check r, SymbolTable symbols) { - ArrayList queries = new ArrayList<>(); + return new Check(r.kind(), queries); + } - for(org.biscuitsec.biscuit.datalog.Rule q: r.queries()) { - queries.add(Rule.convert_from(q, symbols)); - } + @Override + public String toString() { + final List qs = + queries.stream().map((q) -> q.bodyToString()).collect(Collectors.toList()); - return new Check(r.kind(), queries); + if (kind == ONE) { + return "check if " + String.join(" or ", qs); + } else { + return "check all " + String.join(" or ", qs); } + } - @Override - public String toString() { - final List qs = queries.stream().map((q) -> q.bodyToString()).collect(Collectors.toList()); - - if(kind == One) { - return "check if " + String.join(" or ", qs); - } else { - return "check all " + String.join(" or ", qs); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - Check check = (Check) o; + Check check = (Check) o; - return queries != null ? queries.equals(check.queries) : check.queries == null; - } + return queries != null ? queries.equals(check.queries) : check.queries == null; + } - @Override - public int hashCode() { - return queries != null ? queries.hashCode() : 0; - } + @Override + public int hashCode() { + return queries != null ? queries.hashCode() : 0; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/Expression.java b/src/main/java/org/biscuitsec/biscuit/token/builder/Expression.java index 2f0c07a2..e2e99fef 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/Expression.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/Expression.java @@ -1,402 +1,487 @@ package org.biscuitsec.biscuit.token.builder; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; +import java.util.Set; import org.biscuitsec.biscuit.datalog.SymbolTable; -import java.util.*; - public abstract class Expression { - public org.biscuitsec.biscuit.datalog.expressions.Expression convert(SymbolTable symbols) { - ArrayList ops = new ArrayList<>(); - this.toOpcodes(symbols, ops); - - return new org.biscuitsec.biscuit.datalog.expressions.Expression(ops); - } - public static Expression convert_from(org.biscuitsec.biscuit.datalog.expressions.Expression e, SymbolTable symbols) { - ArrayList ops = new ArrayList<>(); - Deque stack = new ArrayDeque(16); - for(org.biscuitsec.biscuit.datalog.expressions.Op op: e.getOps()){ - if(op instanceof org.biscuitsec.biscuit.datalog.expressions.Op.Value) { - org.biscuitsec.biscuit.datalog.expressions.Op.Value v = (org.biscuitsec.biscuit.datalog.expressions.Op.Value) op; - stack.push(new Expression.Value(Term.convert_from(v.getValue(), symbols))); - } else if(op instanceof org.biscuitsec.biscuit.datalog.expressions.Op.Unary) { - org.biscuitsec.biscuit.datalog.expressions.Op.Unary v = (org.biscuitsec.biscuit.datalog.expressions.Op.Unary) op; - Expression e1 = stack.pop(); - - switch (v.getOp()) { - case Length: - stack.push(new Expression.Unary(Op.Length, e1)); - break; - case Negate: - stack.push(new Expression.Unary(Op.Negate, e1)); - break; - case Parens: - stack.push(new Expression.Unary(Op.Parens, e1)); - break; - default: - return null; - } - } else if (op instanceof org.biscuitsec.biscuit.datalog.expressions.Op.Binary) { - org.biscuitsec.biscuit.datalog.expressions.Op.Binary v = (org.biscuitsec.biscuit.datalog.expressions.Op.Binary) op; - Expression e2 = stack.pop(); - Expression e1 = stack.pop(); - - switch (v.getOp()) { - case LessThan: - stack.push(new Expression.Binary(Op.LessThan, e1, e2)); - break; - case GreaterThan: - stack.push(new Expression.Binary(Op.GreaterThan, e1, e2)); - break; - case LessOrEqual: - stack.push(new Expression.Binary(Op.LessOrEqual, e1, e2)); - break; - case GreaterOrEqual: - stack.push(new Expression.Binary(Op.GreaterOrEqual, e1, e2)); - break; - case Equal: - stack.push(new Expression.Binary(Op.Equal, e1, e2)); - break; - case NotEqual: - stack.push(new Expression.Binary(Op.NotEqual, e1, e2)); - break; - case Contains: - stack.push(new Expression.Binary(Op.Contains, e1, e2)); - break; - case Prefix: - stack.push(new Expression.Binary(Op.Prefix, e1, e2)); - break; - case Suffix: - stack.push(new Expression.Binary(Op.Suffix, e1, e2)); - break; - case Regex: - stack.push(new Expression.Binary(Op.Regex, e1, e2)); - break; - case Add: - stack.push(new Expression.Binary(Op.Add, e1, e2)); - break; - case Sub: - stack.push(new Expression.Binary(Op.Sub, e1, e2)); - break; - case Mul: - stack.push(new Expression.Binary(Op.Mul, e1, e2)); - break; - case Div: - stack.push(new Expression.Binary(Op.Div, e1, e2)); - break; - case And: - stack.push(new Expression.Binary(Op.And, e1, e2)); - break; - case Or: - stack.push(new Expression.Binary(Op.Or, e1, e2)); - break; - case Intersection: - stack.push(new Expression.Binary(Op.Intersection, e1, e2)); - break; - case Union: - stack.push(new Expression.Binary(Op.Union, e1, e2)); - break; - case BitwiseAnd: - stack.push(new Expression.Binary(Op.BitwiseAnd, e1, e2)); - break; - case BitwiseOr: - stack.push(new Expression.Binary(Op.BitwiseOr, e1, e2)); - break; - case BitwiseXor: - stack.push(new Expression.Binary(Op.BitwiseXor, e1, e2)); - break; - default: - return null; - } - } + public final org.biscuitsec.biscuit.datalog.expressions.Expression convert(SymbolTable symbolTable) { + ArrayList ops = new ArrayList<>(); + this.toOpcodes(symbolTable, ops); + + return new org.biscuitsec.biscuit.datalog.expressions.Expression(ops); + } + + public static Expression convertFrom( + org.biscuitsec.biscuit.datalog.expressions.Expression e, SymbolTable symbolTable) { + ArrayList ops = new ArrayList<>(); + Deque stack = new ArrayDeque(16); + for (org.biscuitsec.biscuit.datalog.expressions.Op op : e.getOps()) { + if (op instanceof org.biscuitsec.biscuit.datalog.expressions.Op.Value) { + org.biscuitsec.biscuit.datalog.expressions.Op.Value v = + (org.biscuitsec.biscuit.datalog.expressions.Op.Value) op; + stack.push(new Expression.Value(Term.convertFrom(v.getValue(), symbolTable))); + } else if (op instanceof org.biscuitsec.biscuit.datalog.expressions.Op.Unary) { + org.biscuitsec.biscuit.datalog.expressions.Op.Unary v = + (org.biscuitsec.biscuit.datalog.expressions.Op.Unary) op; + Expression e1 = stack.pop(); + + switch (v.getOp()) { + case Length: + stack.push(new Expression.Unary(Op.Length, e1)); + break; + case Negate: + stack.push(new Expression.Unary(Op.Negate, e1)); + break; + case Parens: + stack.push(new Expression.Unary(Op.Parens, e1)); + break; + default: + return null; } - - return stack.pop(); + } else if (op instanceof org.biscuitsec.biscuit.datalog.expressions.Op.Binary) { + org.biscuitsec.biscuit.datalog.expressions.Op.Binary v = + (org.biscuitsec.biscuit.datalog.expressions.Op.Binary) op; + Expression e2 = stack.pop(); + Expression e1 = stack.pop(); + + switch (v.getOp()) { + case LessThan: + stack.push(new Expression.Binary(Op.LessThan, e1, e2)); + break; + case GreaterThan: + stack.push(new Expression.Binary(Op.GreaterThan, e1, e2)); + break; + case LessOrEqual: + stack.push(new Expression.Binary(Op.LessOrEqual, e1, e2)); + break; + case GreaterOrEqual: + stack.push(new Expression.Binary(Op.GreaterOrEqual, e1, e2)); + break; + case Equal: + stack.push(new Expression.Binary(Op.Equal, e1, e2)); + break; + case NotEqual: + stack.push(new Expression.Binary(Op.NotEqual, e1, e2)); + break; + case Contains: + stack.push(new Expression.Binary(Op.Contains, e1, e2)); + break; + case Prefix: + stack.push(new Expression.Binary(Op.Prefix, e1, e2)); + break; + case Suffix: + stack.push(new Expression.Binary(Op.Suffix, e1, e2)); + break; + case Regex: + stack.push(new Expression.Binary(Op.Regex, e1, e2)); + break; + case Add: + stack.push(new Expression.Binary(Op.Add, e1, e2)); + break; + case Sub: + stack.push(new Expression.Binary(Op.Sub, e1, e2)); + break; + case Mul: + stack.push(new Expression.Binary(Op.Mul, e1, e2)); + break; + case Div: + stack.push(new Expression.Binary(Op.Div, e1, e2)); + break; + case And: + stack.push(new Expression.Binary(Op.And, e1, e2)); + break; + case Or: + stack.push(new Expression.Binary(Op.Or, e1, e2)); + break; + case Intersection: + stack.push(new Expression.Binary(Op.Intersection, e1, e2)); + break; + case Union: + stack.push(new Expression.Binary(Op.Union, e1, e2)); + break; + case BitwiseAnd: + stack.push(new Expression.Binary(Op.BitwiseAnd, e1, e2)); + break; + case BitwiseOr: + stack.push(new Expression.Binary(Op.BitwiseOr, e1, e2)); + break; + case BitwiseXor: + stack.push(new Expression.Binary(Op.BitwiseXor, e1, e2)); + break; + default: + return null; + } + } } - public abstract void toOpcodes(SymbolTable symbols, List ops); - public abstract void gatherVariables(Set variables); - - public enum Op { - Negate, - Parens, - LessThan, - GreaterThan, - LessOrEqual, - GreaterOrEqual, - Equal, - NotEqual, - Contains, - Prefix, - Suffix, - Regex, - Add, - Sub, - Mul, - Div, - And, - Or, - Length, - Intersection, - Union, - BitwiseAnd, - BitwiseOr, - BitwiseXor + return stack.pop(); + } + + public abstract void toOpcodes( + SymbolTable symbolTable, List ops); + + public abstract void gatherVariables(Set variables); + + public enum Op { + Negate, + Parens, + LessThan, + GreaterThan, + LessOrEqual, + GreaterOrEqual, + Equal, + NotEqual, + Contains, + Prefix, + Suffix, + Regex, + Add, + Sub, + Mul, + Div, + And, + Or, + Length, + Intersection, + Union, + BitwiseAnd, + BitwiseOr, + BitwiseXor + } + + public static final class Value extends Expression { + public final Term value; + + public Value(Term value) { + this.value = value; } - public final static class Value extends Expression { - public final Term value; - - public Value(Term value) { - this.value = value; - } - - public void toOpcodes(SymbolTable symbols, List ops) { - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Value(this.value.convert(symbols))); - } - - public void gatherVariables(Set variables) { - if(this.value instanceof Term.Variable) { - variables.add(((Term.Variable) this.value).value); - } - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + public void toOpcodes( + SymbolTable symbolTable, List ops) { + ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Value(this.value.convert(symbolTable))); + } - Value value1 = (Value) o; + public void gatherVariables(Set variables) { + if (this.value instanceof Term.Variable) { + variables.add(((Term.Variable) this.value).value); + } + } - return value != null ? value.equals(value1.value) : value1.value == null; - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } - @Override - public int hashCode() { - return value != null ? value.hashCode() : 0; - } + Value value1 = (Value) o; - @Override - public String toString() { - return value.toString(); - } + return value != null ? value.equals(value1.value) : value1.value == null; } - public final static class Unary extends Expression { - private final Op op; - private final Expression arg1; - - public Unary(Op op, Expression arg1) { - this.op = op; - this.arg1 = arg1; - } - - public void toOpcodes(SymbolTable symbols, List ops) { - this.arg1.toOpcodes(symbols, ops); - - switch (this.op) { - case Negate: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Unary(org.biscuitsec.biscuit.datalog.expressions.Op.UnaryOp.Negate)); - break; - case Parens: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Unary(org.biscuitsec.biscuit.datalog.expressions.Op.UnaryOp.Parens)); - break; - case Length: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Unary(org.biscuitsec.biscuit.datalog.expressions.Op.UnaryOp.Length)); - break; - } - } + @Override + public int hashCode() { + return value != null ? value.hashCode() : 0; + } - public void gatherVariables(Set variables) { - this.arg1.gatherVariables(variables); - } + @Override + public String toString() { + return value.toString(); + } + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + public static final class Unary extends Expression { + private final Op op; + private final Expression arg1; - Unary unary = (Unary) o; + public Unary(Op op, Expression arg1) { + this.op = op; + this.arg1 = arg1; + } - if (op != unary.op) return false; - return arg1.equals(unary.arg1); - } + public void toOpcodes( + SymbolTable symbolTable, List ops) { + this.arg1.toOpcodes(symbolTable, ops); + + switch (this.op) { + case Negate: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Unary( + org.biscuitsec.biscuit.datalog.expressions.Op.UnaryOp.Negate)); + break; + case Parens: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Unary( + org.biscuitsec.biscuit.datalog.expressions.Op.UnaryOp.Parens)); + break; + case Length: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Unary( + org.biscuitsec.biscuit.datalog.expressions.Op.UnaryOp.Length)); + break; + default: + throw new RuntimeException("unmapped ops"); + } + } - @Override - public int hashCode() { - int result = op.hashCode(); - result = 31 * result + arg1.hashCode(); - return result; - } + public void gatherVariables(Set variables) { + this.arg1.gatherVariables(variables); + } - @Override - public String toString() { - switch(op) { - case Negate: - return "!"+arg1; - case Parens: - return "("+arg1+")"; - case Length: - return arg1.toString()+".length()"; - } - return ""; - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + Unary unary = (Unary) o; + + if (op != unary.op) { + return false; + } + return arg1.equals(unary.arg1); } - public final static class Binary extends Expression { - private final Op op; - private final Expression arg1; - private final Expression arg2; + @Override + public int hashCode() { + int result = op.hashCode(); + result = 31 * result + arg1.hashCode(); + return result; + } - public Binary(Op op, Expression arg1, Expression arg2) { - this.op = op; - this.arg1 = arg1; - this.arg2 = arg2; - } + @Override + public String toString() { + switch (op) { + case Negate: + return "!" + arg1; + case Parens: + return "(" + arg1 + ")"; + case Length: + return arg1.toString() + ".length()"; + default: + return ""; + } + } + } - public void toOpcodes(SymbolTable symbols, List ops) { - this.arg1.toOpcodes(symbols, ops); - this.arg2.toOpcodes(symbols, ops); - - switch (this.op) { - case LessThan: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.LessThan)); - break; - case GreaterThan: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.GreaterThan)); - break; - case LessOrEqual: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.LessOrEqual)); - break; - case GreaterOrEqual: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.GreaterOrEqual)); - break; - case Equal: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Equal)); - break; - case NotEqual: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.NotEqual)); - break; - case Contains: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Contains)); - break; - case Prefix: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Prefix)); - break; - case Suffix: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Suffix)); - break; - case Regex: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Regex)); - break; - case Add: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Add)); - break; - case Sub: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Sub)); - break; - case Mul: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Mul)); - break; - case Div: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Div)); - break; - case And: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.And)); - break; - case Or: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Or)); - break; - case Intersection: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Intersection)); - break; - case Union: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Union)); - break; - case BitwiseAnd: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.BitwiseAnd)); - break; - case BitwiseOr: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.BitwiseOr)); - break; - case BitwiseXor: - ops.add(new org.biscuitsec.biscuit.datalog.expressions.Op.Binary(org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.BitwiseXor)); - break; - } - } + public static final class Binary extends Expression { + private final Op op; + private final Expression arg1; + private final Expression arg2; - public void gatherVariables(Set variables) { - this.arg1.gatherVariables(variables); - this.arg2.gatherVariables(variables); - } + public Binary(Op op, Expression arg1, Expression arg2) { + this.op = op; + this.arg1 = arg1; + this.arg2 = arg2; + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + public void toOpcodes( + SymbolTable symbolTable, List ops) { + this.arg1.toOpcodes(symbolTable, ops); + this.arg2.toOpcodes(symbolTable, ops); + + switch (this.op) { + case LessThan: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.LessThan)); + break; + case GreaterThan: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.GreaterThan)); + break; + case LessOrEqual: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.LessOrEqual)); + break; + case GreaterOrEqual: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.GreaterOrEqual)); + break; + case Equal: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Equal)); + break; + case NotEqual: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.NotEqual)); + break; + case Contains: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Contains)); + break; + case Prefix: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Prefix)); + break; + case Suffix: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Suffix)); + break; + case Regex: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Regex)); + break; + case Add: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Add)); + break; + case Sub: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Sub)); + break; + case Mul: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Mul)); + break; + case Div: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Div)); + break; + case And: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.And)); + break; + case Or: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Or)); + break; + case Intersection: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Intersection)); + break; + case Union: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.Union)); + break; + case BitwiseAnd: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.BitwiseAnd)); + break; + case BitwiseOr: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.BitwiseOr)); + break; + case BitwiseXor: + ops.add( + new org.biscuitsec.biscuit.datalog.expressions.Op.Binary( + org.biscuitsec.biscuit.datalog.expressions.Op.BinaryOp.BitwiseXor)); + break; + default: + throw new RuntimeException("unmapped ops"); + } + } - Binary binary = (Binary) o; + public void gatherVariables(Set variables) { + this.arg1.gatherVariables(variables); + this.arg2.gatherVariables(variables); + } - if (op != binary.op) return false; - if (!arg1.equals(binary.arg1)) return false; - return arg2.equals(binary.arg2); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + Binary binary = (Binary) o; + + if (op != binary.op) { + return false; + } + if (!arg1.equals(binary.arg1)) { + return false; + } + return arg2.equals(binary.arg2); + } - @Override - public int hashCode() { - int result = op.hashCode(); - result = 31 * result + arg1.hashCode(); - result = 31 * result + arg2.hashCode(); - return result; - } + @Override + public int hashCode() { + int result = op.hashCode(); + result = 31 * result + arg1.hashCode(); + result = 31 * result + arg2.hashCode(); + return result; + } - @Override - public String toString() { - switch(op) { - case LessThan: - return arg1.toString() + " < " + arg2.toString(); - case GreaterThan: - return arg1.toString() + " > " + arg2.toString(); - case LessOrEqual: - return arg1.toString() + " <= " + arg2.toString(); - case GreaterOrEqual: - return arg1.toString() + " >= " + arg2.toString(); - case Equal: - return arg1.toString() + " == " + arg2.toString(); - case NotEqual: - return arg1.toString() + " != " + arg2.toString(); - case Contains: - return arg1.toString() + ".contains(" + arg2.toString()+")"; - case Prefix: - return arg1.toString() + ".starts_with(" + arg2.toString()+")"; - case Suffix: - return arg1.toString() + ".ends_with(" + arg2.toString()+")"; - case Regex: - return arg1.toString() + ".matches(" + arg2.toString()+")"; - case Add: - return arg1.toString() + " + " + arg2.toString(); - case Sub: - return arg1.toString() + " - " + arg2.toString(); - case Mul: - return arg1.toString() + " * " + arg2.toString(); - case Div: - return arg1.toString() + " / " + arg2.toString(); - case And: - return arg1.toString() + " && " + arg2.toString(); - case Or: - return arg1.toString() + " || " + arg2.toString(); - case Intersection: - return arg1.toString() + ".intersection(" + arg2.toString()+")"; - case Union: - return arg1.toString() + ".union(" + arg2.toString()+")"; - case BitwiseAnd: - return arg1.toString() + " & " + arg2.toString(); - case BitwiseOr: - return arg1.toString() + " | " + arg2.toString(); - case BitwiseXor: - return arg1.toString() + " ^ " + arg2.toString(); - } - return ""; - } + @Override + public String toString() { + switch (op) { + case LessThan: + return arg1.toString() + " < " + arg2.toString(); + case GreaterThan: + return arg1.toString() + " > " + arg2.toString(); + case LessOrEqual: + return arg1.toString() + " <= " + arg2.toString(); + case GreaterOrEqual: + return arg1.toString() + " >= " + arg2.toString(); + case Equal: + return arg1.toString() + " == " + arg2.toString(); + case NotEqual: + return arg1.toString() + " != " + arg2.toString(); + case Contains: + return arg1.toString() + ".contains(" + arg2.toString() + ")"; + case Prefix: + return arg1.toString() + ".starts_with(" + arg2.toString() + ")"; + case Suffix: + return arg1.toString() + ".ends_with(" + arg2.toString() + ")"; + case Regex: + return arg1.toString() + ".matches(" + arg2.toString() + ")"; + case Add: + return arg1.toString() + " + " + arg2.toString(); + case Sub: + return arg1.toString() + " - " + arg2.toString(); + case Mul: + return arg1.toString() + " * " + arg2.toString(); + case Div: + return arg1.toString() + " / " + arg2.toString(); + case And: + return arg1.toString() + " && " + arg2.toString(); + case Or: + return arg1.toString() + " || " + arg2.toString(); + case Intersection: + return arg1.toString() + ".intersection(" + arg2.toString() + ")"; + case Union: + return arg1.toString() + ".union(" + arg2.toString() + ")"; + case BitwiseAnd: + return arg1.toString() + " & " + arg2.toString(); + case BitwiseOr: + return arg1.toString() + " | " + arg2.toString(); + case BitwiseXor: + return arg1.toString() + " ^ " + arg2.toString(); + default: + return ""; + } } + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/Fact.java b/src/main/java/org/biscuitsec/biscuit/token/builder/Fact.java index a06c5bf2..e7cb42fb 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/Fact.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/Fact.java @@ -1,133 +1,147 @@ package org.biscuitsec.biscuit.token.builder; -import org.biscuitsec.biscuit.datalog.SymbolTable; -import org.biscuitsec.biscuit.error.Error; -import org.biscuitsec.biscuit.error.FailedCheck; import io.vavr.control.Option; - import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.Stream; +import org.biscuitsec.biscuit.datalog.SymbolTable; +import org.biscuitsec.biscuit.error.Error; +import org.biscuitsec.biscuit.error.FailedCheck; -public class Fact implements Cloneable{ - Predicate predicate; - Option>> variables; - - public Fact(String name, List terms) { - Map> variables = new HashMap>(); - for (Term term : terms) { - if (term instanceof Term.Variable) { - variables.put(((Term.Variable) term).value, Option.none()); - } - } - this.predicate = new Predicate(name, terms); - this.variables = Option.some(variables); - } - - public Fact(Predicate p) { - this.predicate = p; - this.variables = Option.none(); - } - - private Fact(Predicate predicate, Option>> variables){ - this.predicate = predicate; - this.variables = variables; - } - - public void validate() throws Error.Language { - if (!this.variables.isEmpty()) { - List invalid_variables = variables.get().entrySet().stream().flatMap( - e -> { - if (e.getValue().isEmpty()) { - return Stream.of(e.getKey()); - } else { - return Stream.empty(); - } - }).collect(Collectors.toList()); - if (!invalid_variables.isEmpty()) { - throw new Error.Language(new FailedCheck.LanguageError.Builder(invalid_variables)); - } - } - } - - public Fact set(String name, Term term) throws Error.Language { - if (this.variables.isEmpty()) { - throw new Error.Language(new FailedCheck.LanguageError.UnknownVariable(name)); - } - Map> _variables = this.variables.get(); - Option r = _variables.get(name); - if (r != null) { - _variables.put(name, Option.some(term)); - } else { - throw new Error.Language(new FailedCheck.LanguageError.UnknownVariable(name)); - } - return this; - } - - public Fact apply_variables() { - this.variables.forEach( - _variables -> { - this.predicate.terms = this.predicate.terms.stream().flatMap(t -> { - if(t instanceof Term.Variable){ - Option term = _variables.getOrDefault(((Term.Variable) t).value, Option.none()); - return term.map(_t -> Stream.of(_t)).getOrElse(Stream.empty()); - } else return Stream.of(t); - }).collect(Collectors.toList()); - }); - return this; - } - - public org.biscuitsec.biscuit.datalog.Fact convert(SymbolTable symbols) { - Fact f = this.clone(); - f.apply_variables(); - return new org.biscuitsec.biscuit.datalog.Fact(f.predicate.convert(symbols)); - } +public final class Fact implements Cloneable { + Predicate predicate; + Option>> variables; - public static Fact convert_from(org.biscuitsec.biscuit.datalog.Fact f, SymbolTable symbols) { - return new Fact(Predicate.convert_from(f.predicate(), symbols)); + public Fact(String name, List terms) { + Map> variables = new HashMap>(); + for (Term term : terms) { + if (term instanceof Term.Variable) { + variables.put(((Term.Variable) term).value, Option.none()); + } } - - @Override - public String toString() { - Fact f = this.clone(); - f.apply_variables(); - return f.predicate.toString(); + this.predicate = new Predicate(name, terms); + this.variables = Option.some(variables); + } + + public Fact(Predicate p) { + this.predicate = p; + this.variables = Option.none(); + } + + private Fact(Predicate predicate, Option>> variables) { + this.predicate = predicate; + this.variables = variables; + } + + public void validate() throws Error.Language { + if (!this.variables.isEmpty()) { + List invalidVariables = + variables.get().entrySet().stream() + .flatMap( + e -> { + if (e.getValue().isEmpty()) { + return Stream.of(e.getKey()); + } else { + return Stream.empty(); + } + }) + .collect(Collectors.toList()); + if (!invalidVariables.isEmpty()) { + throw new Error.Language(new FailedCheck.LanguageError.Builder(invalidVariables)); + } } + } - public String name() { - return this.predicate.name; + public Fact set(String name, Term term) throws Error.Language { + if (this.variables.isEmpty()) { + throw new Error.Language(new FailedCheck.LanguageError.UnknownVariable(name)); } - - public List terms() { - return this.predicate.terms; + Map> localVariables = this.variables.get(); + Option r = localVariables.get(name); + if (r != null) { + localVariables.put(name, Option.some(term)); + } else { + throw new Error.Language(new FailedCheck.LanguageError.UnknownVariable(name)); } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - Fact fact = (Fact) o; - - return predicate != null ? predicate.equals(fact.predicate) : fact.predicate == null; + return this; + } + + public Fact applyVariables() { + this.variables.forEach( + laVariables -> { + this.predicate.terms = + this.predicate.terms.stream() + .flatMap( + t -> { + if (t instanceof Term.Variable) { + Option term = + laVariables.getOrDefault(((Term.Variable) t).value, Option.none()); + return term.map(t2 -> Stream.of(t2)).getOrElse(Stream.empty()); + } else { + return Stream.of(t); + } + }) + .collect(Collectors.toList()); + }); + return this; + } + + public org.biscuitsec.biscuit.datalog.Fact convert(SymbolTable symbolTable) { + Fact f = this.clone(); + f.applyVariables(); + return new org.biscuitsec.biscuit.datalog.Fact(f.predicate.convert(symbolTable)); + } + + public static Fact convertFrom(org.biscuitsec.biscuit.datalog.Fact f, SymbolTable symbolTable) { + return new Fact(Predicate.convertFrom(f.predicate(), symbolTable)); + } + + @Override + public String toString() { + Fact f = this.clone(); + f.applyVariables(); + return f.predicate.toString(); + } + + public String name() { + return this.predicate.name; + } + + public List terms() { + return this.predicate.terms; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - return predicate != null ? predicate.hashCode() : 0; + if (o == null || getClass() != o.getClass()) { + return false; } - @Override - public Fact clone(){ - Predicate p = this.predicate.clone(); - Option>> variables = this.variables.map(_v -> - { - Map> m = new HashMap<>(); - m.putAll(_v); - return m; - }); - return new Fact(p, variables); - } + Fact fact = (Fact) o; + + return predicate != null ? predicate.equals(fact.predicate) : fact.predicate == null; + } + + @Override + public int hashCode() { + return predicate != null ? predicate.hashCode() : 0; + } + + @Override + public Fact clone() { + Predicate p = this.predicate.clone(); + Option>> variables = + this.variables.map( + v -> { + Map> m = new HashMap<>(); + m.putAll(v); + return m; + }); + return new Fact(p, variables); + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/Predicate.java b/src/main/java/org/biscuitsec/biscuit/token/builder/Predicate.java index 84ea49db..66164b67 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/Predicate.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/Predicate.java @@ -1,78 +1,85 @@ package org.biscuitsec.biscuit.token.builder; -import org.biscuitsec.biscuit.datalog.SymbolTable; - import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; +import org.biscuitsec.biscuit.datalog.SymbolTable; -public class Predicate implements Cloneable { - String name; - List terms; +public final class Predicate implements Cloneable { + String name; + List terms; - public Predicate(String name, List terms) { - this.name = name; - this.terms = terms; - } + public Predicate(String name, List terms) { + this.name = name; + this.terms = terms; + } - public String getName() { - return name; - } + public String getName() { + return name; + } - public List getTerms() { - return terms; - } + public List getTerms() { + return terms; + } - public org.biscuitsec.biscuit.datalog.Predicate convert(SymbolTable symbols) { - long name = symbols.insert(this.name); - ArrayList terms = new ArrayList<>(); + public org.biscuitsec.biscuit.datalog.Predicate convert(SymbolTable symbolTable) { + long name = symbolTable.insert(this.name); + ArrayList terms = new ArrayList<>(); - for(Term a: this.terms) { - terms.add(a.convert(symbols)); - } - - return new org.biscuitsec.biscuit.datalog.Predicate(name, terms); + for (Term a : this.terms) { + terms.add(a.convert(symbolTable)); } - public static Predicate convert_from(org.biscuitsec.biscuit.datalog.Predicate p, SymbolTable symbols) { - String name = symbols.print_symbol((int) p.name()); - List terms = new ArrayList<>(); - for(org.biscuitsec.biscuit.datalog.Term t: p.terms()) { - terms.add(t.toTerm(symbols)); - } - - return new Predicate(name, terms); - } + return new org.biscuitsec.biscuit.datalog.Predicate(name, terms); + } - @Override - public String toString() { - final List i = terms.stream().map((term) -> term.toString()).collect(Collectors.toList()); - return ""+name+"("+String.join(", ", i)+")"; + public static Predicate convertFrom( + org.biscuitsec.biscuit.datalog.Predicate p, SymbolTable symbolTable) { + String name = symbolTable.formatSymbol((int) p.name()); + List terms = new ArrayList<>(); + for (org.biscuitsec.biscuit.datalog.Term t : p.terms()) { + terms.add(t.toTerm(symbolTable)); } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + return new Predicate(name, terms); + } - Predicate predicate = (Predicate) o; + @Override + public String toString() { + final List i = + terms.stream().map((term) -> term.toString()).collect(Collectors.toList()); + return "" + name + "(" + String.join(", ", i) + ")"; + } - if (name != null ? !name.equals(predicate.name) : predicate.name != null) return false; - return terms != null ? terms.equals(predicate.terms) : predicate.terms == null; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public int hashCode() { - int result = name != null ? name.hashCode() : 0; - result = 31 * result + (terms != null ? terms.hashCode() : 0); - return result; + if (o == null || getClass() != o.getClass()) { + return false; } - @Override - public Predicate clone(){ - String name = this.name; - List terms = new ArrayList(this.terms.size()); - terms.addAll(this.terms); - return new Predicate(name, terms); + Predicate predicate = (Predicate) o; + + if (name != null ? !name.equals(predicate.name) : predicate.name != null) { + return false; } + return terms != null ? terms.equals(predicate.terms) : predicate.terms == null; + } + + @Override + public int hashCode() { + int result = name != null ? name.hashCode() : 0; + result = 31 * result + (terms != null ? terms.hashCode() : 0); + return result; + } + + @Override + public Predicate clone() { + String name = this.name; + List terms = new ArrayList(this.terms.size()); + terms.addAll(this.terms); + return new Predicate(name, terms); + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/Rule.java b/src/main/java/org/biscuitsec/biscuit/token/builder/Rule.java index 1ffbbae4..f17053ba 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/Rule.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/Rule.java @@ -1,237 +1,283 @@ package org.biscuitsec.biscuit.token.builder; -import org.biscuitsec.biscuit.datalog.SymbolTable; -import org.biscuitsec.biscuit.error.Error; -import org.biscuitsec.biscuit.error.FailedCheck; import io.vavr.control.Either; import io.vavr.control.Option; - -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; +import org.biscuitsec.biscuit.datalog.SymbolTable; +import org.biscuitsec.biscuit.error.Error; +import org.biscuitsec.biscuit.error.FailedCheck; -public class Rule implements Cloneable { - Predicate head; - List body; - List expressions; - Option>> variables; - List scopes; - - public Rule(Predicate head, List body, List expressions, List scopes) { - Map> variables = new HashMap<>(); - this.head = head; - this.body = body; - this.expressions = expressions; - this.scopes = scopes; - for (Term t : head.terms) { - if (t instanceof Term.Variable) { - variables.put(((Term.Variable) t).value, Option.none()); - } - } - for (Predicate p : body) { - for (Term t : p.terms) { - if (t instanceof Term.Variable) { - variables.put(((Term.Variable) t).value, Option.none()); - } - } - } - for (Expression e : expressions) { - if (e instanceof Expression.Value) { - Expression.Value ev = (Expression.Value) e; - if (ev.value instanceof Term.Variable) - variables.put(((Term.Variable) ev.value).value, Option.none()); - } +public final class Rule implements Cloneable { + Predicate head; + List body; + List expressions; + Option>> variables; + List scopes; + + public Rule( + Predicate head, List body, List expressions, List scopes) { + this.head = head; + this.body = body; + this.expressions = expressions; + this.scopes = scopes; + Map> variables = new HashMap<>(); + for (Term t : head.terms) { + if (t instanceof Term.Variable) { + variables.put(((Term.Variable) t).value, Option.none()); + } + } + for (Predicate p : body) { + for (Term t : p.terms) { + if (t instanceof Term.Variable) { + variables.put(((Term.Variable) t).value, Option.none()); } - this.variables = Option.some(variables); - } - - @Override - public Rule clone() { - Predicate head = this.head.clone(); - List body = new ArrayList<>(); - body.addAll(this.body); - List expressions = new ArrayList<>(); - expressions.addAll(this.expressions); - List scopes = new ArrayList<>(); - scopes.addAll(this.scopes); - return new Rule(head, body, expressions, scopes); - } - - public void set(String name, Term term) throws Error.Language { - if (this.variables.isDefined()) { - Option> t = Option.of(this.variables.get().get(name)); - if (t.isDefined()) { - this.variables.get().put(name, Option.some(term)); - } else { - throw new Error.Language(new FailedCheck.LanguageError.UnknownVariable("name")); - } - } else { - throw new Error.Language(new FailedCheck.LanguageError.UnknownVariable("name")); + } + } + for (Expression e : expressions) { + if (e instanceof Expression.Value) { + Expression.Value ev = (Expression.Value) e; + if (ev.value instanceof Term.Variable) { + variables.put(((Term.Variable) ev.value).value, Option.none()); } + } } - - public void apply_variables() { - this.variables.forEach( - _variables -> { - this.head.terms = this.head.terms.stream().flatMap(t -> { + this.variables = Option.some(variables); + } + + @Override + public Rule clone() { + List body = new ArrayList<>(); + body.addAll(this.body); + List expressions = new ArrayList<>(); + expressions.addAll(this.expressions); + List scopes = new ArrayList<>(); + scopes.addAll(this.scopes); + Predicate head = this.head.clone(); + return new Rule(head, body, expressions, scopes); + } + + public void set(String name, Term term) throws Error.Language { + if (this.variables.isDefined()) { + Option> t = Option.of(this.variables.get().get(name)); + if (t.isDefined()) { + this.variables.get().put(name, Option.some(term)); + } else { + throw new Error.Language(new FailedCheck.LanguageError.UnknownVariable("name")); + } + } else { + throw new Error.Language(new FailedCheck.LanguageError.UnknownVariable("name")); + } + } + + public void applyVariables() { + this.variables.forEach( + laVariables -> { + this.head.terms = + this.head.terms.stream() + .flatMap( + t -> { if (t instanceof Term.Variable) { - Option term = _variables.getOrDefault(((Term.Variable) t).value, Option.none()); - return term.map(_t -> Stream.of(_t)).getOrElse(Stream.of(t)); - } else return Stream.of(t); - }).collect(Collectors.toList()); - for (Predicate p : this.body) { - p.terms = p.terms.stream().flatMap(t -> { - if (t instanceof Term.Variable) { - Option term = _variables.getOrDefault(((Term.Variable) t).value, Option.none()); - return term.map(_t -> Stream.of(_t)).getOrElse(Stream.of(t)); - } else return Stream.of(t); - }).collect(Collectors.toList()); - } - this.expressions = this.expressions.stream().flatMap( - e -> { - if (e instanceof Expression.Value) { - Expression.Value ev = (Expression.Value) e; - if (ev.value instanceof Term.Variable) { - Option t = _variables.getOrDefault(((Term.Variable) ev.value).value, Option.none()); - if (t.isDefined()) { - return Stream.of(new Expression.Value(t.get())); - } - } - } - return Stream.of(e); - }).collect(Collectors.toList()); - }); - } - - public Either validate_variables() { - Set free_variables = this.head.terms.stream().flatMap(t -> { - if (t instanceof Term.Variable) { - return Stream.of(((Term.Variable) t).value); - } else return Stream.empty(); - }).collect(Collectors.toSet()); - - for(Expression e: this.expressions) { - e.gatherVariables(free_variables); - } - if (free_variables.isEmpty()) { - return Either.right(this); - } - - for (Predicate p : this.body) { - for (Term term : p.terms) { - if (term instanceof Term.Variable) { - free_variables.remove(((Term.Variable) term).value); - if (free_variables.isEmpty()) { - return Either.right(this); - } - } - } - } - - return Either.left("rule head or expressions contains variables that are not used in predicates of the rule's body: " + free_variables.toString()); + Option term = + laVariables.getOrDefault(((Term.Variable) t).value, Option.none()); + return term.map(t2 -> Stream.of(t2)).getOrElse(Stream.of(t)); + } else { + return Stream.of(t); + } + }) + .collect(Collectors.toList()); + for (Predicate p : this.body) { + p.terms = + p.terms.stream() + .flatMap( + t -> { + if (t instanceof Term.Variable) { + Option term = + laVariables.getOrDefault(((Term.Variable) t).value, Option.none()); + return term.map(t2 -> Stream.of(t2)).getOrElse(Stream.of(t)); + } else { + return Stream.of(t); + } + }) + .collect(Collectors.toList()); + } + this.expressions = + this.expressions.stream() + .flatMap( + e -> { + if (e instanceof Expression.Value) { + Expression.Value ev = (Expression.Value) e; + if (ev.value instanceof Term.Variable) { + Option t = + laVariables.getOrDefault( + ((Term.Variable) ev.value).value, Option.none()); + if (t.isDefined()) { + return Stream.of(new Expression.Value(t.get())); + } + } + } + return Stream.of(e); + }) + .collect(Collectors.toList()); + }); + } + + public Either validateVariables() { + Set freeVariables = + this.head.terms.stream() + .flatMap( + t -> { + if (t instanceof Term.Variable) { + return Stream.of(((Term.Variable) t).value); + } else { + return Stream.empty(); + } + }) + .collect(Collectors.toSet()); + + for (Expression e : this.expressions) { + e.gatherVariables(freeVariables); + } + if (freeVariables.isEmpty()) { + return Either.right(this); } - public org.biscuitsec.biscuit.datalog.Rule convert(SymbolTable symbols) { - Rule r = this.clone(); - r.apply_variables(); - org.biscuitsec.biscuit.datalog.Predicate head = r.head.convert(symbols); - ArrayList body = new ArrayList<>(); - ArrayList expressions = new ArrayList<>(); - ArrayList scopes = new ArrayList<>(); - - - for (Predicate p : r.body) { - body.add(p.convert(symbols)); - } - - for (Expression e : r.expressions) { - expressions.add(e.convert(symbols)); + for (Predicate p : this.body) { + for (Term term : p.terms) { + if (term instanceof Term.Variable) { + freeVariables.remove(((Term.Variable) term).value); + if (freeVariables.isEmpty()) { + return Either.right(this); + } } + } + } - for (Scope s : r.scopes) { - scopes.add(s.convert(symbols)); - } + return Either.left( + "rule head or expressions contains variables that are not " + + "used in predicates of the rule's body: " + + freeVariables.toString()); + } + + public org.biscuitsec.biscuit.datalog.Rule convert(SymbolTable symbolTable) { + Rule r = this.clone(); + r.applyVariables(); + ArrayList body = new ArrayList<>(); + ArrayList expressions = + new ArrayList<>(); + ArrayList scopes = new ArrayList<>(); + + for (Predicate p : r.body) { + body.add(p.convert(symbolTable)); + } - return new org.biscuitsec.biscuit.datalog.Rule(head, body, expressions, scopes); + for (Expression e : r.expressions) { + expressions.add(e.convert(symbolTable)); } - public static Rule convert_from(org.biscuitsec.biscuit.datalog.Rule r, SymbolTable symbols) { - Predicate head = Predicate.convert_from(r.head(), symbols); + for (Scope s : r.scopes) { + scopes.add(s.convert(symbolTable)); + } + org.biscuitsec.biscuit.datalog.Predicate head = r.head.convert(symbolTable); + return new org.biscuitsec.biscuit.datalog.Rule(head, body, expressions, scopes); + } - ArrayList body = new ArrayList<>(); - ArrayList expressions = new ArrayList<>(); - ArrayList scopes = new ArrayList<>(); + public static Rule convertFrom(org.biscuitsec.biscuit.datalog.Rule r, SymbolTable symbolTable) { + ArrayList body = new ArrayList<>(); + ArrayList expressions = new ArrayList<>(); + ArrayList scopes = new ArrayList<>(); + for (org.biscuitsec.biscuit.datalog.Predicate p : r.body()) { + body.add(Predicate.convertFrom(p, symbolTable)); + } - for (org.biscuitsec.biscuit.datalog.Predicate p : r.body()) { - body.add(Predicate.convert_from(p, symbols)); - } + for (org.biscuitsec.biscuit.datalog.expressions.Expression e : r.expressions()) { + expressions.add(Expression.convertFrom(e, symbolTable)); + } - for (org.biscuitsec.biscuit.datalog.expressions.Expression e : r.expressions()) { - expressions.add(Expression.convert_from(e, symbols)); - } + for (org.biscuitsec.biscuit.datalog.Scope s : r.scopes()) { + scopes.add(Scope.convertFrom(s, symbolTable)); + } - for (org.biscuitsec.biscuit.datalog.Scope s : r.scopes()) { - scopes.add(Scope.convert_from(s, symbols)); - } + Predicate head = Predicate.convertFrom(r.head(), symbolTable); + return new Rule(head, body, expressions, scopes); + } - return new Rule(head, body, expressions, scopes); + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - Rule rule = (Rule) o; + Rule rule = (Rule) o; - if (head != null ? !head.equals(rule.head) : rule.head != null) return false; - if (body != null ? !body.equals(rule.body) : rule.body != null) return false; - if (scopes != null ? !scopes.equals(rule.scopes) : rule.scopes != null) return false; - return expressions != null ? expressions.equals(rule.expressions) : rule.expressions == null; + if (head != null ? !head.equals(rule.head) : rule.head != null) { + return false; } - - @Override - public int hashCode() { - int result = head != null ? head.hashCode() : 0; - result = 31 * result + (body != null ? body.hashCode() : 0); - result = 31 * result + (expressions != null ? expressions.hashCode() : 0); - result = 31 * result + (scopes != null ? scopes.hashCode() : 0); - return result; + if (body != null ? !body.equals(rule.body) : rule.body != null) { + return false; + } + if (scopes != null ? !scopes.equals(rule.scopes) : rule.scopes != null) { + return false; + } + return expressions != null ? expressions.equals(rule.expressions) : rule.expressions == null; + } + + @Override + public int hashCode() { + int result = head != null ? head.hashCode() : 0; + result = 31 * result + (body != null ? body.hashCode() : 0); + result = 31 * result + (expressions != null ? expressions.hashCode() : 0); + result = 31 * result + (scopes != null ? scopes.hashCode() : 0); + return result; + } + + public String bodyToString() { + Rule r = this.clone(); + r.applyVariables(); + String res = ""; + + if (!r.body.isEmpty()) { + final List b = + r.body.stream().map((pred) -> pred.toString()).collect(Collectors.toList()); + res += String.join(", ", b); } - public String bodyToString() { - Rule r = this.clone(); - r.apply_variables(); - String res = ""; - - if(!r.body.isEmpty()) { - final List b = r.body.stream().map((pred) -> pred.toString()).collect(Collectors.toList()); - res += String.join(", ", b); - } + if (!r.expressions.isEmpty()) { + if (!r.body.isEmpty()) { + res += ", "; + } + final List e = + r.expressions.stream() + .map((expression) -> expression.toString()) + .collect(Collectors.toList()); + res += String.join(", ", e); + } - if (!r.expressions.isEmpty()) { - if(!r.body.isEmpty()) { - res += ", "; - } - final List e = r.expressions.stream().map((expression) -> expression.toString()).collect(Collectors.toList()); - res += String.join(", ", e); - } + if (!r.scopes.isEmpty()) { + if (!r.body.isEmpty() || !r.expressions.isEmpty()) { + res += " "; + } + final List e = + r.scopes.stream().map((scope) -> scope.toString()).collect(Collectors.toList()); + res += "trusting " + String.join(", ", e); + } - if(!r.scopes.isEmpty()) { - if(!r.body.isEmpty() || !r.expressions.isEmpty()) { - res += " "; - } - final List e = r.scopes.stream().map((scope) -> scope.toString()).collect(Collectors.toList()); - res += "trusting " + String.join(", ", e); - } + return res; + } - return res; - } - @Override - public String toString() { - Rule r = this.clone(); - r.apply_variables(); - return r.head.toString() + " <- " + bodyToString(); - } + @Override + public String toString() { + Rule r = this.clone(); + r.applyVariables(); + return r.head.toString() + " <- " + bodyToString(); + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/Scope.java b/src/main/java/org/biscuitsec/biscuit/token/builder/Scope.java index 3e25a8c1..4c0436e2 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/Scope.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/Scope.java @@ -1,122 +1,132 @@ package org.biscuitsec.biscuit.token.builder; +import java.util.Objects; import org.biscuitsec.biscuit.crypto.PublicKey; import org.biscuitsec.biscuit.datalog.SymbolTable; -import java.util.Objects; - -public class Scope { - enum Kind { - Authority, - Previous, - PublicKey, - Parameter, - } - - Kind kind; - PublicKey publicKey; - String parameter; - - private Scope(Kind kind) { - this.kind = kind; - this.publicKey = null; - this.parameter = ""; - } - - private Scope(Kind kind, PublicKey publicKey) { - this.kind = kind; - this.publicKey = publicKey; - this.parameter = ""; - } - - private Scope(Kind kind, String parameter) { - this.kind = kind; - this.publicKey = null; - this.parameter = parameter; +public final class Scope { + enum Kind { + Authority, + Previous, + PublicKey, + Parameter, + } + + private Kind kind; + private PublicKey publicKey; + private String parameter; + + private Scope(Kind kind) { + this.kind = kind; + this.publicKey = null; + this.parameter = ""; + } + + private Scope(Kind kind, PublicKey publicKey) { + this.kind = kind; + this.publicKey = publicKey; + this.parameter = ""; + } + + private Scope(Kind kind, String parameter) { + this.kind = kind; + this.publicKey = null; + this.parameter = parameter; + } + + public static Scope authority() { + return new Scope(Kind.Authority); + } + + public static Scope previous() { + return new Scope(Kind.Previous); + } + + public static Scope publicKey(PublicKey publicKey) { + return new Scope(Kind.PublicKey, publicKey); + } + + public static Scope parameter(String parameter) { + return new Scope(Kind.Parameter, parameter); + } + + public org.biscuitsec.biscuit.datalog.Scope convert(SymbolTable symbolTable) { + switch (this.kind) { + case Authority: + return org.biscuitsec.biscuit.datalog.Scope.authority(); + case Previous: + return org.biscuitsec.biscuit.datalog.Scope.previous(); + case Parameter: + // FIXME + return null; + // throw new Exception("Remaining parameter: " + this.parameter); + case PublicKey: + return org.biscuitsec.biscuit.datalog.Scope.publicKey(symbolTable.insert(this.publicKey)); + default: + return null; } + } - public static Scope authority() { + public static Scope convertFrom(org.biscuitsec.biscuit.datalog.Scope scope, SymbolTable symbolTable) { + switch (scope.kind()) { + case Authority: return new Scope(Kind.Authority); - } - - public static Scope previous() { + case Previous: return new Scope(Kind.Previous); + case PublicKey: + // FIXME error management should bubble up here + return new Scope(Kind.PublicKey, symbolTable.getPublicKey((int) scope.getPublicKey()).get()); + default: + return null; } - public static Scope publicKey(PublicKey publicKey) { - return new Scope(Kind.PublicKey, publicKey); - } + // FIXME error management should bubble up here + // throw new Exception("panic"); + // return null; - public static Scope parameter(String parameter) { - return new Scope(Kind.Parameter, parameter); - } + } - public org.biscuitsec.biscuit.datalog.Scope convert(SymbolTable symbols) { - switch (this.kind) { - case Authority: - return org.biscuitsec.biscuit.datalog.Scope.authority(); - case Previous: - return org.biscuitsec.biscuit.datalog.Scope.previous(); - case Parameter: - //FIXME - return null; - //throw new Exception("Remaining parameter: "+this.parameter); - case PublicKey: - return org.biscuitsec.biscuit.datalog.Scope.publicKey(symbols.insert(this.publicKey)); - } - //FIXME - return null; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - public static Scope convert_from(org.biscuitsec.biscuit.datalog.Scope scope, SymbolTable symbols) { - switch (scope.kind()) { - case Authority: - return new Scope(Kind.Authority); - case Previous: - return new Scope(Kind.Previous); - case PublicKey: - //FIXME error management should bubble up here - return new Scope(Kind.PublicKey, symbols.get_pk((int) scope.publicKey()).get()); - } - - //FIXME error management should bubble up here - //throw new Exception("panic"); - return null; - + if (o == null || getClass() != o.getClass()) { + return false; } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - Scope scope = (Scope) o; + Scope scope = (Scope) o; - if (kind != scope.kind) return false; - if (!Objects.equals(publicKey, scope.publicKey)) return false; - return Objects.equals(parameter, scope.parameter); + if (kind != scope.kind) { + return false; } - - @Override - public int hashCode() { - int result = kind.hashCode(); - result = 31 * result + (publicKey != null ? publicKey.hashCode() : 0); - result = 31 * result + (parameter != null ? parameter.hashCode() : 0); - return result; + if (!Objects.equals(publicKey, scope.publicKey)) { + return false; } - - @Override - public String toString() { - switch (this.kind) { - case Authority: - return "authority"; - case Previous: - return "previous"; - case Parameter: - return "{" + this.parameter + "}"; - case PublicKey: - return this.publicKey.toString(); - } - return null; + return Objects.equals(parameter, scope.parameter); + } + + @Override + public int hashCode() { + int result = kind.hashCode(); + result = 31 * result + (publicKey != null ? publicKey.hashCode() : 0); + result = 31 * result + (parameter != null ? parameter.hashCode() : 0); + return result; + } + + @Override + public String toString() { + switch (this.kind) { + case Authority: + return "authority"; + case Previous: + return "previous"; + case Parameter: + return "{" + this.parameter + "}"; + case PublicKey: + return this.publicKey.toString(); + default: } + return null; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/Term.java b/src/main/java/org/biscuitsec/biscuit/token/builder/Term.java index 47c08c58..4c72885b 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/Term.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/Term.java @@ -1,7 +1,5 @@ package org.biscuitsec.biscuit.token.builder; -import org.biscuitsec.biscuit.datalog.SymbolTable; - import java.time.Instant; import java.time.ZoneOffset; import java.time.format.DateTimeFormatter; @@ -9,278 +7,310 @@ import java.util.Collections; import java.util.HashSet; import java.util.Objects; +import org.biscuitsec.biscuit.datalog.SymbolTable; public abstract class Term { - abstract public org.biscuitsec.biscuit.datalog.Term convert(SymbolTable symbols); - static public Term convert_from(org.biscuitsec.biscuit.datalog.Term id, SymbolTable symbols) { - return id.toTerm(symbols); - } + public abstract org.biscuitsec.biscuit.datalog.Term convert(SymbolTable symbolTable); + + public static Term convertFrom(org.biscuitsec.biscuit.datalog.Term id, SymbolTable symbols) { + return id.toTerm(symbols); + } - public static class Str extends Term { - final String value; + public static final class Str extends Term { + final String value; - public Str(String value) { - this.value = value; - } + public Str(String value) { + this.value = value; + } - @Override - public org.biscuitsec.biscuit.datalog.Term convert(SymbolTable symbols) { - return new org.biscuitsec.biscuit.datalog.Term.Str(symbols.insert(this.value)); - } + @Override + public org.biscuitsec.biscuit.datalog.Term convert(SymbolTable symbolTable) { + return new org.biscuitsec.biscuit.datalog.Term.Str(symbolTable.insert(this.value)); + } - public String getValue() { - return value; - } + public String getValue() { + return value; + } - @Override - public String toString() { - return "\""+value+"\""; - } + @Override + public String toString() { + return "\"" + value + "\""; + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Str s = (Str) o; - return Objects.equals(value, s.value); - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Str s = (Str) o; + return Objects.equals(value, s.value); + } - @Override - public int hashCode() { - return value.hashCode(); - } + @Override + public int hashCode() { + return value.hashCode(); } + } - public static class Variable extends Term { - final String value; + public static final class Variable extends Term { + final String value; - public Variable(String value) { - this.value = value; - } + public Variable(String value) { + this.value = value; + } - @Override - public org.biscuitsec.biscuit.datalog.Term convert(SymbolTable symbols) { - return new org.biscuitsec.biscuit.datalog.Term.Variable(symbols.insert(this.value)); - } + @Override + public org.biscuitsec.biscuit.datalog.Term convert(SymbolTable symbolTable) { + return new org.biscuitsec.biscuit.datalog.Term.Variable(symbolTable.insert(this.value)); + } - public String getValue() { - return value; - } + public String getValue() { + return value; + } - @Override - public String toString() { - return "$"+value; - } + @Override + public String toString() { + return "$" + value; + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } - Variable variable = (Variable) o; + Variable variable = (Variable) o; - return value.equals(variable.value); - } + return value.equals(variable.value); + } - @Override - public int hashCode() { - return value.hashCode(); - } + @Override + public int hashCode() { + return value.hashCode(); } + } - public static class Integer extends Term { - final long value; + public static final class Integer extends Term { + final long value; - public Integer(long value) { - this.value = value; - } + public Integer(long value) { + this.value = value; + } - public long getValue() { - return value; - } + public long getValue() { + return value; + } - @Override - public org.biscuitsec.biscuit.datalog.Term convert(SymbolTable symbols) { - return new org.biscuitsec.biscuit.datalog.Term.Integer(this.value); - } + @Override + public org.biscuitsec.biscuit.datalog.Term convert(SymbolTable symbolTable) { + return new org.biscuitsec.biscuit.datalog.Term.Integer(this.value); + } - @Override - public String toString() { - return String.valueOf(value); - } + @Override + public String toString() { + return String.valueOf(value); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } - Integer integer = (Integer) o; + Integer integer = (Integer) o; - return value == integer.value; - } + return value == integer.value; + } - @Override - public int hashCode() { - return Long.hashCode(value); - } + @Override + public int hashCode() { + return Long.hashCode(value); } + } - public static class Bytes extends Term { - final byte[] value; + public static final class Bytes extends Term { + final byte[] value; - public Bytes(byte[] value) { - this.value = value; - } + public Bytes(byte[] value) { + this.value = value; + } - @Override - public org.biscuitsec.biscuit.datalog.Term convert(SymbolTable symbols) { - return new org.biscuitsec.biscuit.datalog.Term.Bytes(this.value); - } + @Override + public org.biscuitsec.biscuit.datalog.Term convert(SymbolTable symbolTable) { + return new org.biscuitsec.biscuit.datalog.Term.Bytes(this.value); + } - public byte[] getValue() { - return Arrays.copyOf(value, value.length); - } + public byte[] getValue() { + return Arrays.copyOf(value, value.length); + } - @Override - public String toString() { - return "hex:" + Utils.byteArrayToHexString(value).toLowerCase(); - } + @Override + public String toString() { + return "hex:" + Utils.byteArrayToHexString(value).toLowerCase(); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } - Bytes bytes = (Bytes) o; + Bytes bytes = (Bytes) o; - return Arrays.equals(value, bytes.value); - } + return Arrays.equals(value, bytes.value); + } - @Override - public int hashCode() { - return Arrays.hashCode(value); - } + @Override + public int hashCode() { + return Arrays.hashCode(value); } + } - public static class Date extends Term { - final long value; + public static final class Date extends Term { + final long value; - public Date(long value) { - this.value = value; - } + public Date(long value) { + this.value = value; + } - @Override - public org.biscuitsec.biscuit.datalog.Term convert(SymbolTable symbols) { - return new org.biscuitsec.biscuit.datalog.Term.Date(this.value); - } + @Override + public org.biscuitsec.biscuit.datalog.Term convert(SymbolTable symbolTable) { + return new org.biscuitsec.biscuit.datalog.Term.Date(this.value); + } - public long getValue() { - return value; - } + public long getValue() { + return value; + } - @Override - public String toString() { - DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ISO_INSTANT; - return Instant.ofEpochSecond(value).atOffset(ZoneOffset.ofTotalSeconds(0)).format(dateTimeFormatter); - } + @Override + public String toString() { + DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ISO_INSTANT; + return Instant.ofEpochSecond(value) + .atOffset(ZoneOffset.ofTotalSeconds(0)) + .format(dateTimeFormatter); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } - Date date = (Date) o; + Date date = (Date) o; - return value == date.value; - } + return value == date.value; + } - @Override - public int hashCode() { - return Long.hashCode(value); - } + @Override + public int hashCode() { + return Long.hashCode(value); } + } - public static class Bool extends Term { - final boolean value; + public static final class Bool extends Term { + final boolean value; - public Bool(boolean value) { - this.value = value; - } + public Bool(boolean value) { + this.value = value; + } - @Override - public org.biscuitsec.biscuit.datalog.Term convert(SymbolTable symbols) { - return new org.biscuitsec.biscuit.datalog.Term.Bool(this.value); - } + @Override + public org.biscuitsec.biscuit.datalog.Term convert(SymbolTable symbolTable) { + return new org.biscuitsec.biscuit.datalog.Term.Bool(this.value); + } - public boolean getValue() { - return value; - } + public boolean getValue() { + return value; + } - @Override - public String toString() { - if(value) { - return "true"; - } else { - return "false"; - } - } + @Override + public String toString() { + if (value) { + return "true"; + } else { + return "false"; + } + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } - Bool bool = (Bool) o; + Bool bool = (Bool) o; - return value == bool.value; - } + return value == bool.value; + } - @Override - public int hashCode() { - return Boolean.hashCode(value); - } + @Override + public int hashCode() { + return Boolean.hashCode(value); } + } - public static class Set extends Term { - final java.util.Set value; + public static final class Set extends Term { + final java.util.Set value; - public Set(java.util.Set value) { - this.value = value; - } + public Set(java.util.Set value) { + this.value = value; + } - @Override - public org.biscuitsec.biscuit.datalog.Term convert(SymbolTable symbols) { - HashSet s = new HashSet<>(); + @Override + public org.biscuitsec.biscuit.datalog.Term convert(SymbolTable symbolTable) { + HashSet s = new HashSet<>(); - for(Term t: this.value) { - s.add(t.convert(symbols)); - } + for (Term t : this.value) { + s.add(t.convert(symbolTable)); + } - return new org.biscuitsec.biscuit.datalog.Term.Set(s); - } + return new org.biscuitsec.biscuit.datalog.Term.Set(s); + } - public java.util.Set getValue() { - return Collections.unmodifiableSet(value); - } + public java.util.Set getValue() { + return Collections.unmodifiableSet(value); + } - @Override - public String toString() { - return value.toString(); - } + @Override + public String toString() { + return value.toString(); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } - Set set = (Set) o; + Set set = (Set) o; - return Objects.equals(value, set.value); - } + return Objects.equals(value, set.value); + } - @Override - public int hashCode() { - return value != null ? value.hashCode() : 0; - } + @Override + public int hashCode() { + return value != null ? value.hashCode() : 0; } + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/Utils.java b/src/main/java/org/biscuitsec/biscuit/token/builder/Utils.java index 9ab879cb..a25d32fc 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/Utils.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/Utils.java @@ -1,81 +1,85 @@ package org.biscuitsec.biscuit.token.builder; -import org.biscuitsec.biscuit.error.Error; +import static org.biscuitsec.biscuit.datalog.Check.Kind.ONE; import java.util.ArrayList; import java.util.Date; import java.util.HashSet; import java.util.List; +import org.biscuitsec.biscuit.error.Error; -import static org.biscuitsec.biscuit.datalog.Check.Kind.One; - -public class Utils { - public static Fact fact(String name, List ids) throws Error.Language { - return new Fact(name, ids); - } - - public static Predicate pred(String name, List ids) { - return new Predicate(name, ids); - } - - public static Rule rule(String head_name, List head_ids, - List predicates) { - return new Rule(pred(head_name, head_ids), predicates, new ArrayList<>(), new ArrayList<>()); - } - - public static Rule constrained_rule(String head_name, List head_ids, - List predicates, - List expressions) { - return new Rule(pred(head_name, head_ids), predicates, expressions, new ArrayList<>()); - } - - public static Check check(Rule rule) { - return new Check(One,rule); - } - - public static Term integer(long i) { - return new Term.Integer(i); - } - - public static Term string(String s) { - return new Term.Str(s); - } - - public static Term s(String str) { - return new Term.Str(str); - } - - public static Term date(Date d) { - return new Term.Date(d.getTime() / 1000); - } - - public static Term var(String name) { - return new Term.Variable(name); - } - - public static Term set(HashSet s) { - return new Term.Set(s); - } - - public static final char[] HEX_ARRAY = "0123456789ABCDEF".toCharArray(); - public static String byteArrayToHexString(byte[] bytes) { - char[] hexChars = new char[bytes.length * 2]; - for (int j = 0; j < bytes.length; j++) { - int v = bytes[j] & 0xFF; - hexChars[j * 2] = HEX_ARRAY[v >>> 4]; - hexChars[j * 2 + 1] = HEX_ARRAY[v & 0x0F]; - } - return new String(hexChars); +public final class Utils { + private Utils() {} + + public static Fact fact(String name, List ids) throws Error.Language { + return new Fact(name, ids); + } + + public static Predicate pred(String name, List ids) { + return new Predicate(name, ids); + } + + public static Rule rule(String headName, List headIds, List predicates) { + return new Rule(pred(headName, headIds), predicates, new ArrayList<>(), new ArrayList<>()); + } + + public static Rule constrainedRule( + String headName, + List headIds, + List predicates, + List expressions) { + return new Rule(pred(headName, headIds), predicates, expressions, new ArrayList<>()); + } + + public static Check check(Rule rule) { + return new Check(ONE, rule); + } + + public static Term integer(long i) { + return new Term.Integer(i); + } + + public static Term string(String s) { + return new Term.Str(s); + } + + public static Term str(String str) { + return new Term.Str(str); + } + + public static Term date(Date d) { + return new Term.Date(d.getTime() / 1000); + } + + public static Term var(String name) { + return new Term.Variable(name); + } + + public static Term set(HashSet s) { + return new Term.Set(s); + } + + public static final char[] HEX_ARRAY = "0123456789ABCDEF".toCharArray(); + + public static String byteArrayToHexString(byte[] bytes) { + char[] hexChars = new char[bytes.length * 2]; + for (int j = 0; j < bytes.length; j++) { + int v = bytes[j] & 0xFF; + hexChars[j * 2] = HEX_ARRAY[v >>> 4]; + hexChars[j * 2 + 1] = HEX_ARRAY[v & 0x0F]; } - - public static byte[] hexStringToByteArray(String hex) { - hex = hex.toUpperCase(); - int l = hex.length(); - byte[] data = new byte[l/2]; - for (int i = 0; i < l; i += 2) { - data[i/2] = (byte) ((Character.digit(hex.charAt(i), 16) << 4) - + Character.digit(hex.charAt(i+1), 16)); - } - return data; + return new String(hexChars); + } + + public static byte[] hexStringToByteArray(String hex) { + hex = hex.toUpperCase(); + int l = hex.length(); + byte[] data = new byte[l / 2]; + for (int i = 0; i < l; i += 2) { + data[i / 2] = + (byte) + ((Character.digit(hex.charAt(i), 16) << 4) + Character.digit(hex.charAt(i + 1), 16)); } + return data; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/package-info.java b/src/main/java/org/biscuitsec/biscuit/token/builder/package-info.java index 4b4bae49..8b4414ff 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/package-info.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/package-info.java @@ -1,4 +1,2 @@ -/** - * Builder interface to create tokens and caveats - */ -package org.biscuitsec.biscuit.token.builder; \ No newline at end of file +/** Builder interface to create tokens and caveats */ +package org.biscuitsec.biscuit.token.builder; diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Error.java b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Error.java index f5e5ee32..23da6464 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Error.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Error.java @@ -3,46 +3,49 @@ import com.google.gson.JsonElement; import com.google.gson.JsonObject; -public class Error extends Exception { - String input; - String message; - - public Error(String input, String message) { - super(message); - this.input = input; - this.message = message; +public final class Error extends Exception { + String input; + String message; + + public Error(String input, String message) { + super(message); + this.input = input; + this.message = message; + } + + @Override + public String toString() { + return "Error{" + "input='" + input + '\'' + ", message='" + message + '\'' + '}'; + } + + public JsonElement toJson() { + JsonObject jo = new JsonObject(); + jo.addProperty("input", this.input); + jo.addProperty("message", this.message); + return jo; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; } - - @Override - public String toString() { - return "Error{" + - "input='" + input + '\'' + - ", message='" + message + '\'' + - '}'; - } - - public JsonElement toJson(){ - JsonObject jo = new JsonObject(); - jo.addProperty("input",this.input); - jo.addProperty("message", this.message); - return jo; + if (o == null || getClass() != o.getClass()) { + return false; } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - Error error = (Error) o; - - if (input != null ? !input.equals(error.input) : error.input != null) return false; - return message != null ? message.equals(error.message) : error.message == null; - } + Error error = (Error) o; - @Override - public int hashCode() { - int result = input != null ? input.hashCode() : 0; - result = 31 * result + (message != null ? message.hashCode() : 0); - return result; + if (input != null ? !input.equals(error.input) : error.input != null) { + return false; } + return message != null ? message.equals(error.message) : error.message == null; + } + + @Override + public int hashCode() { + int result = input != null ? input.hashCode() : 0; + result = 31 * result + (message != null ? message.hashCode() : 0); + return result; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/ExpressionParser.java b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/ExpressionParser.java index cf76c1e8..4902bed8 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/ExpressionParser.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/ExpressionParser.java @@ -1,656 +1,652 @@ package org.biscuitsec.biscuit.token.builder.parser; -import org.biscuitsec.biscuit.token.builder.Term; -import io.vavr.Tuple2; -import io.vavr.control.Either; -import org.biscuitsec.biscuit.token.builder.Expression; - import static org.biscuitsec.biscuit.token.builder.parser.Parser.space; import static org.biscuitsec.biscuit.token.builder.parser.Parser.term; -public class ExpressionParser { - public static Either> parse(String s) { - return expr(space(s)); - } - - // Top-lever parser for an expression. Expression parsers are layered in - // order to support operator precedence (see https://en.wikipedia.org/wiki/Operator-precedence_parser). - // - // See https://github.com/biscuit-auth/biscuit/blob/master/SPECIFICATIONS.md#grammar - // for the precedence order of operators in biscuit datalog. - // - // The operators with the lowest precedence are parsed at the outer level, - // and their operands delegate to parsers that progressively handle more - // tightly binding operators. - // - // This level handles the last operator in the precedence list: `||` - // `||` is left associative, so multiple `||` expressions can be combined: - // `a || b || c <=> (a || b) || c` - public static Either> expr(String s) { - Either> res1 = expr1(s); - if (res1.isLeft()) { - return Either.left(res1.getLeft()); - } - Tuple2 t1 = res1.get(); +import io.vavr.Tuple2; +import io.vavr.control.Either; +import org.biscuitsec.biscuit.token.builder.Expression; +import org.biscuitsec.biscuit.token.builder.Term; - s = t1._1; - Expression e = t1._2; +public final class ExpressionParser { + private ExpressionParser() {} + + public static Either> parse(String s) { + return expr(space(s)); + } + + // Top-lever parser for an expression. Expression parsers are layered in + // order to support operator precedence (see + // https://en.wikipedia.org/wiki/Operator-precedence_parser). + // + // See https://github.com/biscuit-auth/biscuit/blob/master/SPECIFICATIONS.md#grammar + // for the precedence order of operators in biscuit datalog. + // + // The operators with the lowest precedence are parsed at the outer level, + // and their operands delegate to parsers that progressively handle more + // tightly binding operators. + // + // This level handles the last operator in the precedence list: `||` + // `||` is left associative, so multiple `||` expressions can be combined: + // `a || b || c <=> (a || b) || c` + public static Either> expr(String s) { + Either> res1 = expr1(s); + if (res1.isLeft()) { + return Either.left(res1.getLeft()); + } + Tuple2 t1 = res1.get(); + + s = t1._1; + Expression e = t1._2; + + while (true) { + s = space(s); + if (s.length() == 0) { + break; + } + + Either> res2 = binaryOp0(s); + if (res2.isLeft()) { + break; + } + Tuple2 t2 = res2.get(); + s = t2._1; + + s = space(s); + + Either> res3 = expr1(s); + if (res3.isLeft()) { + return Either.left(res3.getLeft()); + } + Tuple2 t3 = res3.get(); + + s = t3._1; + Expression e2 = t3._2; + Expression.Op op = t2._2; + e = new Expression.Binary(op, e, e2); + } - while(true) { - s = space(s); - if(s.length() == 0) { - break; - } + return Either.right(new Tuple2<>(s, e)); + } - Either> res2 = binary_op0(s); - if (res2.isLeft()) { - break; - } - Tuple2 t2 = res2.get(); - s = t2._1; - Expression.Op op = t2._2; + /// This level handles `&&` + /// `&&` is left associative, so multiple `&&` expressions can be combined: + /// `a && b && c <=> (a && b) && c` + public static Either> expr1(String s) { + Either> res1 = expr2(s); + if (res1.isLeft()) { + return Either.left(res1.getLeft()); + } + Tuple2 t1 = res1.get(); + + s = t1._1; + Expression e = t1._2; + + while (true) { + s = space(s); + if (s.length() == 0) { + break; + } + + Either> res2 = binaryOp1(s); + if (res2.isLeft()) { + break; + } + Tuple2 t2 = res2.get(); + s = t2._1; + + s = space(s); + + Either> res3 = expr2(s); + if (res3.isLeft()) { + return Either.left(res3.getLeft()); + } + Tuple2 t3 = res3.get(); + + s = t3._1; + Expression e2 = t3._2; + Expression.Op op = t2._2; + e = new Expression.Binary(op, e, e2); + } - s = space(s); + return Either.right(new Tuple2<>(s, e)); + } - Either> res3 = expr1(s); - if (res3.isLeft()) { - return Either.left(res3.getLeft()); - } - Tuple2 t3 = res3.get(); + /// This level handles comparison operators (`==`, `>`, `>=`, `<`, `<=`). + /// Those operators are _not_ associative and require explicit grouping + /// with parentheses. + public static Either> expr2(String s) { + Either> res1 = expr3(s); + if (res1.isLeft()) { + return Either.left(res1.getLeft()); + } + Tuple2 t1 = res1.get(); - s = t3._1; - Expression e2 = t3._2; + s = t1._1; - e = new Expression.Binary(op, e, e2); - } + s = space(s); - return Either.right(new Tuple2<>(s, e)); + Either> res2 = binaryOp2(s); + if (res2.isLeft()) { + return Either.right(t1); } + Tuple2 t2 = res2.get(); + s = t2._1; - /// This level handles `&&` - /// `&&` is left associative, so multiple `&&` expressions can be combined: - /// `a && b && c <=> (a && b) && c` - public static Either> expr1(String s) { - Either> res1 = expr2(s); - if (res1.isLeft()) { - return Either.left(res1.getLeft()); - } - Tuple2 t1 = res1.get(); + s = space(s); - s = t1._1; - Expression e = t1._2; + Either> res3 = expr3(s); + if (res3.isLeft()) { + return Either.left(res3.getLeft()); + } + Tuple2 t3 = res3.get(); + + s = t3._1; + Expression e2 = t3._2; + Expression.Op op = t2._2; + Expression e = t1._2; + e = new Expression.Binary(op, e, e2); + + return Either.right(new Tuple2<>(s, e)); + } + + /// This level handles `|`. + /// It is left associative, so multiple expressions can be combined: + /// `a | b | c <=> (a | b) | c` + public static Either> expr3(String s) { + Either> res1 = expr4(s); + if (res1.isLeft()) { + return Either.left(res1.getLeft()); + } + Tuple2 t1 = res1.get(); - while(true) { - s = space(s); - if(s.length() == 0) { - break; - } + s = t1._1; + Expression e = t1._2; - Either> res2 = binary_op1(s); - if (res2.isLeft()) { - break; - } - Tuple2 t2 = res2.get(); - s = t2._1; - Expression.Op op = t2._2; + while (true) { + s = space(s); + if (s.length() == 0) { + break; + } - s = space(s); + Either> res2 = binaryOp3(s); + if (res2.isLeft()) { + break; + } + Tuple2 t2 = res2.get(); + s = t2._1; - Either> res3 = expr2(s); - if (res3.isLeft()) { - return Either.left(res3.getLeft()); - } - Tuple2 t3 = res3.get(); + s = space(s); - s = t3._1; - Expression e2 = t3._2; + Either> res3 = expr4(s); + if (res3.isLeft()) { + return Either.left(res3.getLeft()); + } + Tuple2 t3 = res3.get(); - e = new Expression.Binary(op, e, e2); - } + s = t3._1; + Expression e2 = t3._2; - return Either.right(new Tuple2<>(s, e)); + Expression.Op op = t2._2; + e = new Expression.Binary(op, e, e2); } - /// This level handles comparison operators (`==`, `>`, `>=`, `<`, `<=`). - /// Those operators are _not_ associative and require explicit grouping - /// with parentheses. - public static Either> expr2(String s) { - Either> res1 = expr3(s); - if (res1.isLeft()) { - return Either.left(res1.getLeft()); - } - Tuple2 t1 = res1.get(); + return Either.right(new Tuple2<>(s, e)); + } - s = t1._1; - Expression e = t1._2; + /// This level handles `^`. + /// It is left associative, so multiple expressions can be combined: + /// `a ^ b ^ c <=> (a ^ b) ^ c` + public static Either> expr4(String s) { + Either> res1 = expr5(s); + if (res1.isLeft()) { + return Either.left(res1.getLeft()); + } + Tuple2 t1 = res1.get(); - s = space(s); + s = t1._1; + Expression e = t1._2; - Either> res2 = binary_op2(s); - if (res2.isLeft()) { - return Either.right(t1); + while (true) { + s = space(s); + if (s.length() == 0) { + break; + } - } - Tuple2 t2 = res2.get(); - s = t2._1; - Expression.Op op = t2._2; + Either> res2 = binaryOp4(s); + if (res2.isLeft()) { + break; + } + Tuple2 t2 = res2.get(); + s = t2._1; - s = space(s); + s = space(s); - Either> res3 = expr3(s); - if (res3.isLeft()) { - return Either.left(res3.getLeft()); - } - Tuple2 t3 = res3.get(); + Either> res3 = expr5(s); + if (res3.isLeft()) { + return Either.left(res3.getLeft()); + } + Tuple2 t3 = res3.get(); - s = t3._1; - Expression e2 = t3._2; + s = t3._1; + Expression e2 = t3._2; - e = new Expression.Binary(op, e, e2); - - return Either.right(new Tuple2<>(s, e)); + Expression.Op op = t2._2; + e = new Expression.Binary(op, e, e2); } - /// This level handles `|`. - /// It is left associative, so multiple expressions can be combined: - /// `a | b | c <=> (a | b) | c` - public static Either> expr3(String s) { - Either> res1 = expr4(s); - if (res1.isLeft()) { - return Either.left(res1.getLeft()); - } - Tuple2 t1 = res1.get(); + return Either.right(new Tuple2<>(s, e)); + } - s = t1._1; - Expression e = t1._2; + /// This level handles `&`. + /// It is left associative, so multiple expressions can be combined: + /// `a & b & c <=> (a & b) & c` + public static Either> expr5(String s) { + Either> res1 = expr6(s); + if (res1.isLeft()) { + return Either.left(res1.getLeft()); + } + Tuple2 t1 = res1.get(); - while(true) { - s = space(s); - if(s.length() == 0) { - break; - } + s = t1._1; + Expression e = t1._2; - Either> res2 = binary_op3(s); - if (res2.isLeft()) { - break; - } - Tuple2 t2 = res2.get(); - s = t2._1; - Expression.Op op = t2._2; + while (true) { + s = space(s); + if (s.length() == 0) { + break; + } - s = space(s); + Either> res2 = binaryOp5(s); + if (res2.isLeft()) { + break; + } + Tuple2 t2 = res2.get(); + s = t2._1; - Either> res3 = expr4(s); - if (res3.isLeft()) { - return Either.left(res3.getLeft()); - } - Tuple2 t3 = res3.get(); + s = space(s); - s = t3._1; - Expression e2 = t3._2; + Either> res3 = expr6(s); + if (res3.isLeft()) { + return Either.left(res3.getLeft()); + } + Tuple2 t3 = res3.get(); - e = new Expression.Binary(op, e, e2); - } + s = t3._1; + Expression e2 = t3._2; - return Either.right(new Tuple2<>(s, e)); + Expression.Op op = t2._2; + e = new Expression.Binary(op, e, e2); } - /// This level handles `^`. - /// It is left associative, so multiple expressions can be combined: - /// `a ^ b ^ c <=> (a ^ b) ^ c` - public static Either> expr4(String s) { - Either> res1 = expr5(s); - if (res1.isLeft()) { - return Either.left(res1.getLeft()); - } - Tuple2 t1 = res1.get(); + return Either.right(new Tuple2<>(s, e)); + } - s = t1._1; - Expression e = t1._2; + /// This level handles `+` and `-`. + /// They are left associative, so multiple expressions can be combined: + /// `a + b - c <=> (a + b) - c` + public static Either> expr6(String s) { + Either> res1 = expr7(s); + if (res1.isLeft()) { + return Either.left(res1.getLeft()); + } + Tuple2 t1 = res1.get(); - while(true) { - s = space(s); - if(s.length() == 0) { - break; - } + s = t1._1; + Expression e = t1._2; - Either> res2 = binary_op4(s); - if (res2.isLeft()) { - break; - } - Tuple2 t2 = res2.get(); - s = t2._1; - Expression.Op op = t2._2; + while (true) { + s = space(s); + if (s.length() == 0) { + break; + } - s = space(s); + Either> res2 = binaryOp6(s); + if (res2.isLeft()) { + break; + } + Tuple2 t2 = res2.get(); + s = t2._1; - Either> res3 = expr5(s); - if (res3.isLeft()) { - return Either.left(res3.getLeft()); - } - Tuple2 t3 = res3.get(); + s = space(s); - s = t3._1; - Expression e2 = t3._2; + Either> res3 = expr7(s); + if (res3.isLeft()) { + return Either.left(res3.getLeft()); + } + Tuple2 t3 = res3.get(); - e = new Expression.Binary(op, e, e2); - } + s = t3._1; + Expression e2 = t3._2; - return Either.right(new Tuple2<>(s, e)); + Expression.Op op = t2._2; + e = new Expression.Binary(op, e, e2); } - /// This level handles `&`. - /// It is left associative, so multiple expressions can be combined: - /// `a & b & c <=> (a & b) & c` - public static Either> expr5(String s) { - Either> res1 = expr6(s); - if (res1.isLeft()) { - return Either.left(res1.getLeft()); - } - Tuple2 t1 = res1.get(); + return Either.right(new Tuple2<>(s, e)); + } - s = t1._1; - Expression e = t1._2; + /// This level handles `*` and `/`. + /// They are left associative, so multiple expressions can be combined: + /// `a * b / c <=> (a * b) / c` + public static Either> expr7(String s) { + Either> res1 = expr8(s); + if (res1.isLeft()) { + return Either.left(res1.getLeft()); + } + Tuple2 t1 = res1.get(); - while(true) { - s = space(s); - if(s.length() == 0) { - break; - } + s = t1._1; + Expression e = t1._2; - Either> res2 = binary_op5(s); - if (res2.isLeft()) { - break; - } - Tuple2 t2 = res2.get(); - s = t2._1; - Expression.Op op = t2._2; + while (true) { + s = space(s); + if (s.length() == 0) { + break; + } - s = space(s); + Either> res2 = binaryOp7(s); + if (res2.isLeft()) { + break; + } + Tuple2 t2 = res2.get(); + s = t2._1; - Either> res3 = expr6(s); - if (res3.isLeft()) { - return Either.left(res3.getLeft()); - } - Tuple2 t3 = res3.get(); + s = space(s); - s = t3._1; - Expression e2 = t3._2; + Either> res3 = expr8(s); + if (res3.isLeft()) { + return Either.left(res3.getLeft()); + } + Tuple2 t3 = res3.get(); - e = new Expression.Binary(op, e, e2); - } + s = t3._1; + Expression e2 = t3._2; - return Either.right(new Tuple2<>(s, e)); + Expression.Op op = t2._2; + e = new Expression.Binary(op, e, e2); } - /// This level handles `+` and `-`. - /// They are left associative, so multiple expressions can be combined: - /// `a + b - c <=> (a + b) - c` - public static Either> expr6(String s) { - Either> res1 = expr7(s); - if (res1.isLeft()) { - return Either.left(res1.getLeft()); - } - Tuple2 t1 = res1.get(); - - s = t1._1; - Expression e = t1._2; - - while(true) { - s = space(s); - if(s.length() == 0) { - break; - } - - Either> res2 = binary_op6(s); - if (res2.isLeft()) { - break; - } - Tuple2 t2 = res2.get(); - s = t2._1; - Expression.Op op = t2._2; + return Either.right(new Tuple2<>(s, e)); + } - s = space(s); + /// This level handles `!` (prefix negation) + public static Either> expr8(String s) { - Either> res3 = expr7(s); - if (res3.isLeft()) { - return Either.left(res3.getLeft()); - } - Tuple2 t3 = res3.get(); + s = space(s); - s = t3._1; - Expression e2 = t3._2; + if (s.startsWith("!")) { + s = space(s.substring(1)); - e = new Expression.Binary(op, e, e2); - } + Either> res = expr9(s); + if (res.isLeft()) { + return Either.left(res.getLeft()); + } - return Either.right(new Tuple2<>(s, e)); + Tuple2 t = res.get(); + return Either.right(new Tuple2<>(t._1, new Expression.Unary(Expression.Op.Negate, t._2))); + } else { + return expr9(s); } + } + + /// This level handles methods. Methods can take either zero or one + /// argument in addition to the expression they are called on. + /// The name of the method decides its arity. + public static Either> expr9(String s) { + Either> res1 = exprTerm(s); + if (res1.isLeft()) { + return Either.left(res1.getLeft()); + } + Tuple2 t1 = res1.get(); - /// This level handles `*` and `/`. - /// They are left associative, so multiple expressions can be combined: - /// `a * b / c <=> (a * b) / c` - public static Either> expr7(String s) { - Either> res1 = expr8(s); - if (res1.isLeft()) { - return Either.left(res1.getLeft()); - } - Tuple2 t1 = res1.get(); - - s = t1._1; - Expression e = t1._2; - - while(true) { - s = space(s); - if(s.length() == 0) { - break; - } - - Either> res2 = binary_op7(s); - if (res2.isLeft()) { - break; - } - Tuple2 t2 = res2.get(); - s = t2._1; - Expression.Op op = t2._2; + s = t1._1; + Expression e = t1._2; - s = space(s); + while (true) { + s = space(s); + if (s.isEmpty()) { + break; + } - Either> res3 = expr8(s); - if (res3.isLeft()) { - return Either.left(res3.getLeft()); - } - Tuple2 t3 = res3.get(); + if (!s.startsWith(".")) { + return Either.right(new Tuple2<>(s, e)); + } - s = t3._1; - Expression e2 = t3._2; + s = s.substring(1); + Either> res2 = binaryOp8(s); + if (!res2.isLeft()) { + Tuple2 t2 = res2.get(); + s = space(t2._1); - e = new Expression.Binary(op, e, e2); + if (!s.startsWith("(")) { + return Either.left(new Error(s, "missing (")); } - return Either.right(new Tuple2<>(s, e)); - } + s = space(s.substring(1)); - /// This level handles `!` (prefix negation) - public static Either> expr8(String s) { - - s = space(s); - - if(s.startsWith("!")) { - s = space(s.substring(1)); + Either> res3 = expr(s); + if (res3.isLeft()) { + return Either.left(res3.getLeft()); + } - Either> res = expr9(s); - if (res.isLeft()) { - return Either.left(res.getLeft()); - } + Tuple2 t3 = res3.get(); - Tuple2 t = res.get(); - return Either.right(new Tuple2<>(t._1, new Expression.Unary(Expression.Op.Negate, t._2))); - } else { - return expr9(s); + s = space(t3._1); + if (!s.startsWith(")")) { + return Either.left(new Error(s, "missing )")); } - } + s = space(s.substring(1)); + Expression e2 = t3._2; - /// This level handles methods. Methods can take either zero or one - /// argument in addition to the expression they are called on. - /// The name of the method decides its arity. - public static Either> expr9(String s) { - Either> res1 = expr_term(s); - if (res1.isLeft()) { - return Either.left(res1.getLeft()); - } - Tuple2 t1 = res1.get(); - - s = t1._1; - Expression e = t1._2; - - while(true) { - s = space(s); - if(s.isEmpty()) { - break; - } - - if (!s.startsWith(".")) { - return Either.right(new Tuple2<>(s, e)); - } - - s = s.substring(1); - Either> res2 = binary_op8(s); - if (!res2.isLeft()) { - Tuple2 t2 = res2.get(); - s = space(t2._1); - Expression.Op op = t2._2; - - if (!s.startsWith("(")) { - return Either.left(new Error(s, "missing (")); - } - - s = space(s.substring(1)); - - Either> res3 = expr(s); - if (res3.isLeft()) { - return Either.left(res3.getLeft()); - } - - Tuple2 t3 = res3.get(); - - s = space(t3._1); - if (!s.startsWith(")")) { - return Either.left(new Error(s, "missing )")); - } - s = space(s.substring(1)); - Expression e2 = t3._2; - - e = new Expression.Binary(op, e, e2); - } else { - if (s.startsWith("length()")) { - e = new Expression.Unary(Expression.Op.Length, e); - s = s.substring(9); - } - } + Expression.Op op = t2._2; + e = new Expression.Binary(op, e, e2); + } else { + if (s.startsWith("length()")) { + e = new Expression.Unary(Expression.Op.Length, e); + s = s.substring(9); } - - return Either.right(new Tuple2<>(s, e)); + } } - public static Either> expr_term(String s) { - Either> res1 = unary_parens(s); - if (res1.isRight()) { - return res1; - } + return Either.right(new Tuple2<>(s, e)); + } - Either> res2 = term(s); - if (res2.isLeft()) { - return Either.left(res2.getLeft()); - } - Tuple2 t2 = res2.get(); - Expression e = new Expression.Value(t2._2); + public static Either> exprTerm(String s) { + Either> res1 = unaryParens(s); + if (res1.isRight()) { + return res1; + } - return Either.right(new Tuple2<>(t2._1, e)); + Either> res2 = term(s); + if (res2.isLeft()) { + return Either.left(res2.getLeft()); } + Tuple2 t2 = res2.get(); + Expression e = new Expression.Value(t2._2); - public static Either> unary(String s) { - s = space(s); + return Either.right(new Tuple2<>(t2._1, e)); + } - if(s.startsWith("!")) { - s = space(s.substring(1)); + public static Either> unary(String s) { + s = space(s); - Either> res = expr(s); - if (res.isLeft()) { - return Either.left(res.getLeft()); - } + if (s.startsWith("!")) { + s = space(s.substring(1)); - Tuple2 t = res.get(); - return Either.right(new Tuple2<>(t._1, new Expression.Unary(Expression.Op.Negate, t._2))); - } + Either> res = expr(s); + if (res.isLeft()) { + return Either.left(res.getLeft()); + } + Tuple2 t = res.get(); + return Either.right(new Tuple2<>(t._1, new Expression.Unary(Expression.Op.Negate, t._2))); + } - if(s.startsWith("(")) { - Either> res = unary_parens(s); - if (res.isLeft()) { - return Either.left(res.getLeft()); - } + if (s.startsWith("(")) { + Either> res = unaryParens(s); + if (res.isLeft()) { + return Either.left(res.getLeft()); + } - Tuple2 t = res.get(); - return Either.right(new Tuple2<>(t._1, t._2)); - } + Tuple2 t = res.get(); + return Either.right(new Tuple2<>(t._1, t._2)); + } - Expression e; - Either> res = term(s); - if (res.isRight()) { - Tuple2 t = res.get(); - s = space(t._1); - e = new Expression.Value(t._2); - } else { - Either> res2 = unary_parens(s); - if (res2.isLeft()) { - return Either.left(res2.getLeft()); - } - - Tuple2 t = res2.get(); - s = space(t._1); - e = t._2; - } + Expression e; + Either> res = term(s); + if (res.isRight()) { + Tuple2 t = res.get(); + s = space(t._1); + e = new Expression.Value(t._2); + } else { + Either> res2 = unaryParens(s); + if (res2.isLeft()) { + return Either.left(res2.getLeft()); + } + + Tuple2 t = res2.get(); + s = space(t._1); + e = t._2; + } - if(s.startsWith(".length()")) { - s = space(s.substring(9)); - return Either.right(new Tuple2<>(s, new Expression.Unary(Expression.Op.Length, e))); - } else { - return Either.left(new Error(s, "unexpected token")); - } + if (s.startsWith(".length()")) { + s = space(s.substring(9)); + return Either.right(new Tuple2<>(s, new Expression.Unary(Expression.Op.Length, e))); + } else { + return Either.left(new Error(s, "unexpected token")); } + } - public static Either> unary_parens(String s) { - if(s.startsWith("(")) { - s = space(s.substring(1)); + public static Either> unaryParens(String s) { + if (s.startsWith("(")) { + s = space(s.substring(1)); - Either> res = expr(s); - if (res.isLeft()) { - return Either.left(res.getLeft()); - } + Either> res = expr(s); + if (res.isLeft()) { + return Either.left(res.getLeft()); + } - Tuple2 t = res.get(); + Tuple2 t = res.get(); - s = space(t._1); - if(!s.startsWith(")")) { - return Either.left(new Error(s, "missing )")); - } + s = space(t._1); + if (!s.startsWith(")")) { + return Either.left(new Error(s, "missing )")); + } - s = space(s.substring(1)); - return Either.right(new Tuple2<>(s, new Expression.Unary(Expression.Op.Parens, t._2))); - } else { - return Either.left(new Error(s, "missing (")); - } + s = space(s.substring(1)); + return Either.right(new Tuple2<>(s, new Expression.Unary(Expression.Op.Parens, t._2))); + } else { + return Either.left(new Error(s, "missing (")); } + } - public static Either> binary_op0(String s) { - if(s.startsWith("||")) { - return Either.right(new Tuple2<>(s.substring(2), Expression.Op.Or)); - } - - return Either.left(new Error(s, "unrecognized op")); + public static Either> binaryOp0(String s) { + if (s.startsWith("||")) { + return Either.right(new Tuple2<>(s.substring(2), Expression.Op.Or)); } - public static Either> binary_op1(String s) { - if(s.startsWith("&&")) { - return Either.right(new Tuple2<>(s.substring(2), Expression.Op.And)); - } + return Either.left(new Error(s, "unrecognized op")); + } - return Either.left(new Error(s, "unrecognized op")); + public static Either> binaryOp1(String s) { + if (s.startsWith("&&")) { + return Either.right(new Tuple2<>(s.substring(2), Expression.Op.And)); } - public static Either> binary_op2(String s) { - if(s.startsWith("<=")) { - return Either.right(new Tuple2<>(s.substring(2), Expression.Op.LessOrEqual)); - } - if(s.startsWith(">=")) { - return Either.right(new Tuple2<>(s.substring(2), Expression.Op.GreaterOrEqual)); - } - if(s.startsWith("<")) { - return Either.right(new Tuple2<>(s.substring(1), Expression.Op.LessThan)); - } - if(s.startsWith(">")) { - return Either.right(new Tuple2<>(s.substring(1), Expression.Op.GreaterThan)); - } - if(s.startsWith("==")) { - return Either.right(new Tuple2<>(s.substring(2), Expression.Op.Equal)); - } - if(s.startsWith("!=")) { - return Either.right(new Tuple2<>(s.substring(2), Expression.Op.NotEqual)); - } + return Either.left(new Error(s, "unrecognized op")); + } - return Either.left(new Error(s, "unrecognized op")); + public static Either> binaryOp2(String s) { + if (s.startsWith("<=")) { + return Either.right(new Tuple2<>(s.substring(2), Expression.Op.LessOrEqual)); + } + if (s.startsWith(">=")) { + return Either.right(new Tuple2<>(s.substring(2), Expression.Op.GreaterOrEqual)); + } + if (s.startsWith("<")) { + return Either.right(new Tuple2<>(s.substring(1), Expression.Op.LessThan)); + } + if (s.startsWith(">")) { + return Either.right(new Tuple2<>(s.substring(1), Expression.Op.GreaterThan)); + } + if (s.startsWith("==")) { + return Either.right(new Tuple2<>(s.substring(2), Expression.Op.Equal)); + } + if (s.startsWith("!=")) { + return Either.right(new Tuple2<>(s.substring(2), Expression.Op.NotEqual)); } + return Either.left(new Error(s, "unrecognized op")); + } - public static Either> binary_op3(String s) { - if(s.startsWith("^")) { - return Either.right(new Tuple2<>(s.substring(1), Expression.Op.BitwiseXor)); - } - - return Either.left(new Error(s, "unrecognized op")); + public static Either> binaryOp3(String s) { + if (s.startsWith("^")) { + return Either.right(new Tuple2<>(s.substring(1), Expression.Op.BitwiseXor)); } - public static Either> binary_op4(String s) { - if(s.startsWith("|") && !s.startsWith("||")) { - return Either.right(new Tuple2<>(s.substring(1), Expression.Op.BitwiseOr)); - } + return Either.left(new Error(s, "unrecognized op")); + } - return Either.left(new Error(s, "unrecognized op")); + public static Either> binaryOp4(String s) { + if (s.startsWith("|") && !s.startsWith("||")) { + return Either.right(new Tuple2<>(s.substring(1), Expression.Op.BitwiseOr)); } - public static Either> binary_op5(String s) { - if(s.startsWith("&") && !s.startsWith("&&")) { - return Either.right(new Tuple2<>(s.substring(1), Expression.Op.BitwiseAnd)); - } + return Either.left(new Error(s, "unrecognized op")); + } - return Either.left(new Error(s, "unrecognized op")); + public static Either> binaryOp5(String s) { + if (s.startsWith("&") && !s.startsWith("&&")) { + return Either.right(new Tuple2<>(s.substring(1), Expression.Op.BitwiseAnd)); } - public static Either> binary_op6(String s) { + return Either.left(new Error(s, "unrecognized op")); + } - if(s.startsWith("+")) { - return Either.right(new Tuple2<>(s.substring(1), Expression.Op.Add)); - } - if(s.startsWith("-")) { - return Either.right(new Tuple2<>(s.substring(1), Expression.Op.Sub)); - } + public static Either> binaryOp6(String s) { - return Either.left(new Error(s, "unrecognized op")); + if (s.startsWith("+")) { + return Either.right(new Tuple2<>(s.substring(1), Expression.Op.Add)); + } + if (s.startsWith("-")) { + return Either.right(new Tuple2<>(s.substring(1), Expression.Op.Sub)); } + return Either.left(new Error(s, "unrecognized op")); + } - public static Either> binary_op7(String s) { - if(s.startsWith("*")) { - return Either.right(new Tuple2<>(s.substring(1), Expression.Op.Mul)); - } - if(s.startsWith("/")) { - return Either.right(new Tuple2<>(s.substring(1), Expression.Op.Div)); - } - - return Either.left(new Error(s, "unrecognized op")); + public static Either> binaryOp7(String s) { + if (s.startsWith("*")) { + return Either.right(new Tuple2<>(s.substring(1), Expression.Op.Mul)); + } + if (s.startsWith("/")) { + return Either.right(new Tuple2<>(s.substring(1), Expression.Op.Div)); } - public static Either> binary_op8(String s) { - if(s.startsWith("intersection")) { - return Either.right(new Tuple2<>(s.substring(12), Expression.Op.Intersection)); - } - if(s.startsWith("union")) { - return Either.right(new Tuple2<>(s.substring(5), Expression.Op.Union)); - } - if(s.startsWith("contains")) { - return Either.right(new Tuple2<>(s.substring(8), Expression.Op.Contains)); - } - if(s.startsWith("starts_with")) { - return Either.right(new Tuple2<>(s.substring(11), Expression.Op.Prefix)); - } - if(s.startsWith("ends_with")) { - return Either.right(new Tuple2<>(s.substring(9), Expression.Op.Suffix)); - } - if(s.startsWith("matches")) { - return Either.right(new Tuple2<>(s.substring(7), Expression.Op.Regex)); - } + return Either.left(new Error(s, "unrecognized op")); + } - return Either.left(new Error(s, "unrecognized op")); + public static Either> binaryOp8(String s) { + if (s.startsWith("intersection")) { + return Either.right(new Tuple2<>(s.substring(12), Expression.Op.Intersection)); + } + if (s.startsWith("union")) { + return Either.right(new Tuple2<>(s.substring(5), Expression.Op.Union)); + } + if (s.startsWith("contains")) { + return Either.right(new Tuple2<>(s.substring(8), Expression.Op.Contains)); + } + if (s.startsWith("starts_with")) { + return Either.right(new Tuple2<>(s.substring(11), Expression.Op.Prefix)); } + if (s.startsWith("ends_with")) { + return Either.right(new Tuple2<>(s.substring(9), Expression.Op.Suffix)); + } + if (s.startsWith("matches")) { + return Either.right(new Tuple2<>(s.substring(7), Expression.Op.Regex)); + } + + return Either.left(new Error(s, "unrecognized op")); + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java index 1722ba7f..d5b3dfa9 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java +++ b/src/main/java/org/biscuitsec/biscuit/token/builder/parser/Parser.java @@ -1,802 +1,845 @@ package org.biscuitsec.biscuit.token.builder.parser; import biscuit.format.schema.Schema; -import io.vavr.collection.Stream; -import org.biscuitsec.biscuit.crypto.PublicKey; -import org.biscuitsec.biscuit.token.Policy; import io.vavr.Tuple2; import io.vavr.Tuple4; +import io.vavr.collection.Stream; import io.vavr.control.Either; -import org.biscuitsec.biscuit.token.builder.*; - import java.time.OffsetDateTime; import java.time.format.DateTimeParseException; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; import java.util.function.Function; +import org.biscuitsec.biscuit.crypto.PublicKey; +import org.biscuitsec.biscuit.token.Policy; +import org.biscuitsec.biscuit.token.builder.Block; +import org.biscuitsec.biscuit.token.builder.Check; +import org.biscuitsec.biscuit.token.builder.Expression; +import org.biscuitsec.biscuit.token.builder.Fact; +import org.biscuitsec.biscuit.token.builder.Predicate; +import org.biscuitsec.biscuit.token.builder.Rule; +import org.biscuitsec.biscuit.token.builder.Scope; +import org.biscuitsec.biscuit.token.builder.Term; +import org.biscuitsec.biscuit.token.builder.Utils; + +public final class Parser { + private Parser() {} + + /** + * Takes a datalog string with \n as datalog line separator. It tries to parse each + * line using fact, rule, check and scope sequentially. + * + *

If one succeeds it returns Right(Block) else it returns a Map[lineNumber, List[Error]] + * + * @param index block index + * @param s datalog string to parse + * @return Either>, Block> + */ + public static Either>, Block> datalog(long index, String s) { + Block blockBuilder = new Block(); + + // empty block code + if (s.isEmpty()) { + return Either.right(blockBuilder); + } -public class Parser { - /** - * Takes a datalog string with \n as datalog line separator. It tries to parse - * each line using fact, rule, check and scope sequentially. - * - * If one succeeds it returns Right(Block) - * else it returns a Map[lineNumber, List[Error]] - * - * @param index block index - * @param s datalog string to parse - * @return Either>, Block> - */ - public static Either>, Block> datalog(long index, String s) { - Block blockBuilder = new Block(); - - // empty block code - if (s.isEmpty()) { - return Either.right(blockBuilder); - } - - Map> errors = new HashMap<>(); - - s = removeCommentsAndWhitespaces(s); - String[] codeLines = s.split(";"); - - Stream.of(codeLines) - .zipWithIndex() - .forEach(indexedLine -> { - String code = indexedLine._1.strip(); - - if (!code.isEmpty()) { - int lineNumber = indexedLine._2; - List lineErrors = new ArrayList<>(); - - boolean parsed = false; - parsed = rule(code).fold(e -> { - lineErrors.add(e); - return false; - }, r -> { - blockBuilder.add_rule(r._2); - return true; - }); - - if (!parsed) { - parsed = fact(code).fold(e -> { - lineErrors.add(e); - return false; - }, r -> { - blockBuilder.add_fact(r._2); - return true; - }); - } - - if (!parsed) { - parsed = check(code).fold(e -> { - lineErrors.add(e); - return false; - }, r -> { - blockBuilder.add_check(r._2); - return true; - }); - } - - if (!parsed) { - parsed = scope(code).fold(e -> { - lineErrors.add(e); - return false; - }, r -> { - blockBuilder.add_scope(r._2); - return true; - }); - } - - if (!parsed) { - lineErrors.forEach(System.out::println); - errors.put(lineNumber, lineErrors); - } - } - }); - - if (!errors.isEmpty()) { - return Either.left(errors); - } + Map> errors = new HashMap<>(); + + s = removeCommentsAndWhitespaces(s); + String[] codeLines = s.split(";"); + + Stream.of(codeLines) + .zipWithIndex() + .forEach( + indexedLine -> { + String code = indexedLine._1.strip(); + + if (!code.isEmpty()) { + List lineErrors = new ArrayList<>(); + + boolean parsed = false; + parsed = + rule(code) + .fold( + e -> { + lineErrors.add(e); + return false; + }, + r -> { + blockBuilder.addRule(r._2); + return true; + }); + + if (!parsed) { + parsed = + fact(code) + .fold( + e -> { + lineErrors.add(e); + return false; + }, + r -> { + blockBuilder.addFact(r._2); + return true; + }); + } - return Either.right(blockBuilder); - } + if (!parsed) { + parsed = + check(code) + .fold( + e -> { + lineErrors.add(e); + return false; + }, + r -> { + blockBuilder.addCheck(r._2); + return true; + }); + } - public static Either> fact(String s) { - Either> res = fact_predicate(s); - if (res.isLeft()) { - return Either.left(res.getLeft()); - } else { - Tuple2 t = res.get(); + if (!parsed) { + parsed = + scope(code) + .fold( + e -> { + lineErrors.add(e); + return false; + }, + r -> { + blockBuilder.addScope(r._2); + return true; + }); + } - if (!t._1.isEmpty()) { - return Either.left(new Error(s, "the string was not entirely parsed, remaining: " + t._1)); - } + if (!parsed) { + lineErrors.forEach(System.out::println); + int lineNumber = indexedLine._2; + errors.put(lineNumber, lineErrors); + } + } + }); - return Either.right(new Tuple2<>(t._1, new Fact(t._2))); - } + if (!errors.isEmpty()) { + return Either.left(errors); } - public static Either> rule(String s) { - Either> res0 = predicate(s); - if (res0.isLeft()) { - return Either.left(res0.getLeft()); - } + return Either.right(blockBuilder); + } - Tuple2 t0 = res0.get(); - s = t0._1; - Predicate head = t0._2; + public static Either> fact(String s) { + Either> res = factPredicate(s); + if (res.isLeft()) { + return Either.left(res.getLeft()); + } else { + Tuple2 t = res.get(); - s = space(s); - if (s.length() < 2 || s.charAt(0) != '<' || s.charAt(1) != '-') { - return Either.left(new Error(s, "rule arrow not found")); - } + if (!t._1.isEmpty()) { + return Either.left(new Error(s, "the string was not entirely parsed, remaining: " + t._1)); + } - List predicates = new ArrayList(); - s = s.substring(2); + return Either.right(new Tuple2<>(t._1, new Fact(t._2))); + } + } - Either, List, List>> bodyRes = rule_body(s); - if (bodyRes.isLeft()) { - return Either.left(bodyRes.getLeft()); - } + public static Either> rule(String s) { + Either> res0 = predicate(s); + if (res0.isLeft()) { + return Either.left(res0.getLeft()); + } - Tuple4, List, List> body = bodyRes.get(); + Tuple2 t0 = res0.get(); + s = t0._1; - if (!body._1.isEmpty()) { - return Either.left(new Error(s, "the string was not entirely parsed, remaining: " + body._1)); - } + s = space(s); + if (s.length() < 2 || s.charAt(0) != '<' || s.charAt(1) != '-') { + return Either.left(new Error(s, "rule arrow not found")); + } - Rule rule = new Rule(head, body._2, body._3, body._4); - Either valid = rule.validate_variables(); - if (valid.isLeft()) { - return Either.left(new Error(s, valid.getLeft())); - } + List predicates = new ArrayList(); + s = s.substring(2); - return Either.right(new Tuple2<>(body._1, rule)); + Either, List, List>> bodyRes = + ruleBody(s); + if (bodyRes.isLeft()) { + return Either.left(bodyRes.getLeft()); } - public static Either> check(String s) { - org.biscuitsec.biscuit.datalog.Check.Kind kind; + Tuple4, List, List> body = bodyRes.get(); - if (s.startsWith("check if")) { - kind = org.biscuitsec.biscuit.datalog.Check.Kind.One; - s = s.substring("check if".length()); - } else if (s.startsWith("check all")) { - kind = org.biscuitsec.biscuit.datalog.Check.Kind.All; - s = s.substring("check all".length()); - } else { - return Either.left(new Error(s, "missing check prefix")); - } + if (!body._1.isEmpty()) { + return Either.left(new Error(s, "the string was not entirely parsed, remaining: " + body._1)); + } - List queries = new ArrayList<>(); - Either>> bodyRes = check_body(s); - if (bodyRes.isLeft()) { - return Either.left(bodyRes.getLeft()); - } + Predicate head = t0._2; + Rule rule = new Rule(head, body._2, body._3, body._4); + Either valid = rule.validateVariables(); + if (valid.isLeft()) { + return Either.left(new Error(s, valid.getLeft())); + } - Tuple2> t = bodyRes.get(); + return Either.right(new Tuple2<>(body._1, rule)); + } - if (!t._1.isEmpty()) { - return Either.left(new Error(s, "the string was not entirely parsed, remaining: " + t._1)); - } + public static Either> check(String s) { + org.biscuitsec.biscuit.datalog.Check.Kind kind; - return Either.right(new Tuple2<>(t._1, new Check(kind, t._2))); + if (s.startsWith("check if")) { + kind = org.biscuitsec.biscuit.datalog.Check.Kind.ONE; + s = s.substring("check if".length()); + } else if (s.startsWith("check all")) { + kind = org.biscuitsec.biscuit.datalog.Check.Kind.ALL; + s = s.substring("check all".length()); + } else { + return Either.left(new Error(s, "missing check prefix")); } - public static Either> policy(String s) { - Policy.Kind p = Policy.Kind.Allow; + List queries = new ArrayList<>(); + Either>> bodyRes = checkBody(s); + if (bodyRes.isLeft()) { + return Either.left(bodyRes.getLeft()); + } - String allow = "allow if"; - String deny = "deny if"; - if (s.startsWith(allow)) { - s = s.substring(allow.length()); - } else if (s.startsWith(deny)) { - p = Policy.Kind.Deny; - s = s.substring(deny.length()); - } else { - return Either.left(new Error(s, "missing policy prefix")); - } + Tuple2> t = bodyRes.get(); - List queries = new ArrayList<>(); - Either>> bodyRes = check_body(s); - if (bodyRes.isLeft()) { - return Either.left(bodyRes.getLeft()); - } + if (!t._1.isEmpty()) { + return Either.left(new Error(s, "the string was not entirely parsed, remaining: " + t._1)); + } + + return Either.right(new Tuple2<>(t._1, new Check(kind, t._2))); + } + + public static Either> policy(String s) { + Policy.Kind p = Policy.Kind.ALLOW; + + String allow = "allow if"; + String deny = "deny if"; + if (s.startsWith(allow)) { + s = s.substring(allow.length()); + } else if (s.startsWith(deny)) { + p = Policy.Kind.DENY; + s = s.substring(deny.length()); + } else { + return Either.left(new Error(s, "missing policy prefix")); + } - Tuple2> t = bodyRes.get(); + List queries = new ArrayList<>(); + Either>> bodyRes = checkBody(s); + if (bodyRes.isLeft()) { + return Either.left(bodyRes.getLeft()); + } - if (!t._1.isEmpty()) { - return Either.left(new Error(s, "the string was not entirely parsed, remaining: " + t._1)); - } + Tuple2> t = bodyRes.get(); - return Either.right(new Tuple2<>(t._1, new Policy(t._2, p))); + if (!t._1.isEmpty()) { + return Either.left(new Error(s, "the string was not entirely parsed, remaining: " + t._1)); } - public static Either>> check_body(String s) { - List queries = new ArrayList<>(); - Either, List, List>> bodyRes = rule_body(s); - if (bodyRes.isLeft()) { - return Either.left(bodyRes.getLeft()); - } + return Either.right(new Tuple2<>(t._1, new Policy(t._2, p))); + } - Tuple4, List, List> body = bodyRes.get(); + public static Either>> checkBody(String s) { + List queries = new ArrayList<>(); + Either, List, List>> bodyRes = + ruleBody(s); + if (bodyRes.isLeft()) { + return Either.left(bodyRes.getLeft()); + } - s = body._1; - //FIXME: parse scopes - queries.add(new Rule(new Predicate("query", new ArrayList<>()), body._2, body._3, body._4)); + Tuple4, List, List> body = bodyRes.get(); - int i = 0; - while (true) { - if (s.length() == 0) { - break; - } + s = body._1; + // FIXME: parse scopes + queries.add(new Rule(new Predicate("query", new ArrayList<>()), body._2, body._3, body._4)); - s = space(s); + int i = 0; + while (true) { + if (s.length() == 0) { + break; + } - if (!s.startsWith("or")) { - break; - } - s = s.substring(2); + s = space(s); - Either, List, List>> bodyRes2 = rule_body(s); - if (bodyRes2.isLeft()) { - return Either.left(bodyRes2.getLeft()); - } + if (!s.startsWith("or")) { + break; + } + s = s.substring(2); - Tuple4, List, List> body2 = bodyRes2.get(); + Either, List, List>> bodyRes2 = + ruleBody(s); + if (bodyRes2.isLeft()) { + return Either.left(bodyRes2.getLeft()); + } - s = body2._1; - queries.add(new Rule(new Predicate("query", new ArrayList<>()), body2._2, body2._3, body2._4)); - } + Tuple4, List, List> body2 = bodyRes2.get(); - return Either.right(new Tuple2<>(s, queries)); - } - - public static Either, List, List>> rule_body(String s) { - List predicates = new ArrayList(); - List expressions = new ArrayList<>(); - - while (true) { - s = space(s); - - Either> res = predicate(s); - if (res.isRight()) { - Tuple2 t = res.get(); - s = t._1; - predicates.add(t._2); - } else { - Either> res2 = expression(s); - if (res2.isRight()) { - Tuple2 t2 = res2.get(); - s = t2._1; - expressions.add(t2._2); - } else { - break; - } - } + s = body2._1; + queries.add( + new Rule(new Predicate("query", new ArrayList<>()), body2._2, body2._3, body2._4)); + } - s = space(s); + return Either.right(new Tuple2<>(s, queries)); + } - if (s.length() == 0 || s.charAt(0) != ',') { - break; - } else { - s = s.substring(1); - } - } + public static Either, List, List>> + ruleBody(String s) { + List predicates = new ArrayList(); + List expressions = new ArrayList<>(); - Either>> res = scopes(s); - if(res.isLeft()) { - return Either.right(new Tuple4<>(s, predicates, expressions, new ArrayList<>())); - } else { - Tuple2> t = res.get(); - return Either.right(new Tuple4<>(t._1, predicates, expressions, t._2)); + while (true) { + s = space(s); + Either> res = predicate(s); + if (res.isRight()) { + Tuple2 t = res.get(); + s = t._1; + predicates.add(t._2); + } else { + Either> res2 = expression(s); + if (res2.isRight()) { + Tuple2 t2 = res2.get(); + s = t2._1; + expressions.add(t2._2); + } else { + break; } + } - } + s = space(s); - public static Either> predicate(String s) { - Tuple2 tn = take_while(s, (c) -> Character.isAlphabetic(c) || Character.isDigit(c) || c == '_' || c == ':'); - String name = tn._1; - s = tn._2; - - s = space(s); - if (s.length() == 0 || s.charAt(0) != '(') { - return Either.left(new Error(s, "opening parens not found for predicate "+name)); - } + if (s.length() == 0 || s.charAt(0) != ',') { + break; + } else { s = s.substring(1); + } + } - List terms = new ArrayList(); - while (true) { + Either>> res = scopes(s); + if (res.isLeft()) { + return Either.right(new Tuple4<>(s, predicates, expressions, new ArrayList<>())); + } else { + Tuple2> t = res.get(); + return Either.right(new Tuple4<>(t._1, predicates, expressions, t._2)); + } + } + + public static Either> predicate(String s) { + Tuple2 tn = + takewhile( + s, (c) -> Character.isAlphabetic(c) || Character.isDigit(c) || c == '_' || c == ':'); + String name = tn._1; + s = tn._2; + + s = space(s); + if (s.length() == 0 || s.charAt(0) != '(') { + return Either.left(new Error(s, "opening parens not found for predicate " + name)); + } + s = s.substring(1); - s = space(s); + List terms = new ArrayList(); + while (true) { - Either> res = term(s); - if (res.isLeft()) { - break; - } + s = space(s); - Tuple2 t = res.get(); - s = t._1; - terms.add(t._2); + Either> res = term(s); + if (res.isLeft()) { + break; + } - s = space(s); + Tuple2 t = res.get(); + s = t._1; + terms.add(t._2); - if (s.length() == 0 || s.charAt(0) != ',') { - break; - } else { - s = s.substring(1); - } - } + s = space(s); - s = space(s); - if (0 == s.length() || s.charAt(0) != ')') { - return Either.left(new Error(s, "closing parens not found")); - } - String remaining = s.substring(1); + if (s.length() == 0 || s.charAt(0) != ',') { + break; + } else { + s = s.substring(1); + } + } - return Either.right(new Tuple2(remaining, new Predicate(name, terms))); + s = space(s); + if (0 == s.length() || s.charAt(0) != ')') { + return Either.left(new Error(s, "closing parens not found")); } + String remaining = s.substring(1); - public static Either>> scopes(String s) { - if (!s.startsWith("trusting")) { - return Either.left(new Error(s, "missing scopes prefix")); - } - s = s.substring("trusting".length()); - s = space(s); + return Either.right(new Tuple2(remaining, new Predicate(name, terms))); + } - List scopes = new ArrayList(); + public static Either>> scopes(String s) { + if (!s.startsWith("trusting")) { + return Either.left(new Error(s, "missing scopes prefix")); + } + s = s.substring("trusting".length()); + s = space(s); - while (true) { - s = space(s); + List scopes = new ArrayList(); - Either> res = scope(s); - if (res.isLeft()) { - break; - } + while (true) { + s = space(s); - Tuple2 t = res.get(); - s = t._1; - scopes.add(t._2); + Either> res = scope(s); + if (res.isLeft()) { + break; + } - s = space(s); + Tuple2 t = res.get(); + s = t._1; + scopes.add(t._2); - if (s.length() == 0 || s.charAt(0) != ',') { - break; - } else { - s = s.substring(1); - } - } + s = space(s); - return Either.right(new Tuple2<>(s, scopes)); + if (s.length() == 0 || s.charAt(0) != ',') { + break; + } else { + s = s.substring(1); + } } - public static Either> scope(String s) { - if (s.startsWith("authority")) { - s = s.substring("authority".length()); - return Either.right(new Tuple2<>(s, Scope.authority())); - } - - if (s.startsWith("previous")) { - s = s.substring("previous".length()); - return Either.right(new Tuple2<>(s, Scope.previous())); - } + return Either.right(new Tuple2<>(s, scopes)); + } - if (0 < s.length() && s.charAt(0) == '{') { - String remaining = s.substring(1); - Either> res = name(remaining); - if (res.isLeft()) { - return Either.left(new Error(s, "unrecognized parameter")); - } - Tuple2 t = res.get(); - if (0 < s.length() && s.charAt(0) == '}') { - return Either.right(new Tuple2<>(t._1, Scope.parameter(t._2))); - } else { - return Either.left(new Error(s, "unrecognized parameter end")); - } - } - - Either> res2 = publicKey(s); - if (res2.isLeft()) { - return Either.left(new Error(s, "unrecognized public key")); - } - Tuple2 t = res2.get(); - return Either.right(new Tuple2<>(t._1, Scope.publicKey(t._2))); - } - - public static Either> publicKey(String s) { - if (s.startsWith("ed25519/")) { - s = s.substring("ed25519/".length()); - Tuple2 t = hex(s); - return Either.right(new Tuple2(t._1, new PublicKey(Schema.PublicKey.Algorithm.Ed25519, t._2))); - } else if (s.startsWith("secp256r1/")) { - s = s.substring("secp256r1/".length()); - Tuple2 t = hex(s); - return Either.right(new Tuple2(t._1, new PublicKey(Schema.PublicKey.Algorithm.SECP256R1, t._2))); - } else { - return Either.left(new Error(s, "unrecognized public key prefix")); - } + public static Either> scope(String s) { + if (s.startsWith("authority")) { + s = s.substring("authority".length()); + return Either.right(new Tuple2<>(s, Scope.authority())); } - public static Either> fact_predicate(String s) { - Tuple2 tn = take_while(s, (c) -> Character.isAlphabetic(c) || Character.isDigit(c) || c == '_' || c == ':'); - String name = tn._1; - s = tn._2; - - s = space(s); - if (s.length() == 0 || s.charAt(0) != '(') { - return Either.left(new Error(s, "opening parens not found for fact "+name)); - } - s = s.substring(1); + if (s.startsWith("previous")) { + s = s.substring("previous".length()); + return Either.right(new Tuple2<>(s, Scope.previous())); + } - List terms = new ArrayList(); - while (true) { + if (0 < s.length() && s.charAt(0) == '{') { + String remaining = s.substring(1); + Either> res = name(remaining); + if (res.isLeft()) { + return Either.left(new Error(s, "unrecognized parameter")); + } + Tuple2 t = res.get(); + if (0 < s.length() && s.charAt(0) == '}') { + return Either.right(new Tuple2<>(t._1, Scope.parameter(t._2))); + } else { + return Either.left(new Error(s, "unrecognized parameter end")); + } + } - s = space(s); + Either> res2 = publicKey(s); + if (res2.isLeft()) { + return Either.left(new Error(s, "unrecognized public key")); + } + Tuple2 t = res2.get(); + return Either.right(new Tuple2<>(t._1, Scope.publicKey(t._2))); + } + + public static Either> publicKey(String s) { + if (s.startsWith("ed25519/")) { + s = s.substring("ed25519/".length()); + Tuple2 t = hex(s); + return Either.right( + new Tuple2(t._1, new PublicKey(Schema.PublicKey.Algorithm.Ed25519, t._2))); + } else if (s.startsWith("secp256r1/")) { + s = s.substring("secp256r1/".length()); + Tuple2 t = hex(s); + return Either.right( + new Tuple2(t._1, new PublicKey(Schema.PublicKey.Algorithm.SECP256R1, t._2))); + } else { + return Either.left(new Error(s, "unrecognized public key prefix")); + } + } + + public static Either> factPredicate(String s) { + Tuple2 tn = + takewhile( + s, (c) -> Character.isAlphabetic(c) || Character.isDigit(c) || c == '_' || c == ':'); + String name = tn._1; + s = tn._2; + + s = space(s); + if (s.length() == 0 || s.charAt(0) != '(') { + return Either.left(new Error(s, "opening parens not found for fact " + name)); + } + s = s.substring(1); - Either> res = fact_term(s); - if (res.isLeft()) { - break; - } + List terms = new ArrayList(); + while (true) { - Tuple2 t = res.get(); - s = t._1; - terms.add(t._2); + s = space(s); - s = space(s); + Either> res = factTerm(s); + if (res.isLeft()) { + break; + } - if (s.length() == 0 || s.charAt(0) != ',') { - break; - } else { - s = s.substring(1); - } - } + Tuple2 t = res.get(); + s = t._1; + terms.add(t._2); - s = space(s); - if (0 == s.length() || s.charAt(0) != ')') { - return Either.left(new Error(s, "closing parens not found")); - } - String remaining = s.substring(1); + s = space(s); - return Either.right(new Tuple2(remaining, new Predicate(name, terms))); + if (s.length() == 0 || s.charAt(0) != ',') { + break; + } else { + s = s.substring(1); + } } - public static Either> name(String s) { - Tuple2 t = take_while(s, (c) -> Character.isAlphabetic(c) || c == '_'); - String name = t._1; - String remaining = t._2; - - return Either.right(new Tuple2(remaining, name)); + s = space(s); + if (0 == s.length() || s.charAt(0) != ')') { + return Either.left(new Error(s, "closing parens not found")); } + String remaining = s.substring(1); - public static Either> term(String s) { + return Either.right(new Tuple2(remaining, new Predicate(name, terms))); + } - Either> res5 = variable(s); - if (res5.isRight()) { - Tuple2 t = res5.get(); - return Either.right(new Tuple2<>(t._1, t._2)); - } + public static Either> name(String s) { + Tuple2 t = takewhile(s, (c) -> Character.isAlphabetic(c) || c == '_'); + String name = t._1; + String remaining = t._2; - Either> res2 = string(s); - if (res2.isRight()) { - Tuple2 t = res2.get(); - return Either.right(new Tuple2<>(t._1, t._2)); - } + return Either.right(new Tuple2(remaining, name)); + } - Either> res7 = set(s); - if (res7.isRight()) { - Tuple2 t = res7.get(); - return Either.right(new Tuple2<>(t._1, t._2)); - } + public static Either> term(String s) { - Either> res6 = bool(s); - if (res6.isRight()) { - Tuple2 t = res6.get(); - return Either.right(new Tuple2<>(t._1, t._2)); - } + Either> res5 = variable(s); + if (res5.isRight()) { + Tuple2 t = res5.get(); + return Either.right(new Tuple2<>(t._1, t._2)); + } - Either> res4 = date(s); - if (res4.isRight()) { - Tuple2 t = res4.get(); - return Either.right(new Tuple2<>(t._1, t._2)); - } + Either> res2 = string(s); + if (res2.isRight()) { + Tuple2 t = res2.get(); + return Either.right(new Tuple2<>(t._1, t._2)); + } - Either> res3 = integer(s); - if (res3.isRight()) { - Tuple2 t = res3.get(); - return Either.right(new Tuple2<>(t._1, t._2)); - } + Either> res7 = set(s); + if (res7.isRight()) { + Tuple2 t = res7.get(); + return Either.right(new Tuple2<>(t._1, t._2)); + } - Either> res8 = bytes(s); - if (res8.isRight()) { - Tuple2 t = res8.get(); - return Either.right(new Tuple2<>(t._1, t._2)); - } + Either> res6 = bool(s); + if (res6.isRight()) { + Tuple2 t = res6.get(); + return Either.right(new Tuple2<>(t._1, t._2)); + } - return Either.left(new Error(s, "unrecognized value")); + Either> res4 = date(s); + if (res4.isRight()) { + Tuple2 t = res4.get(); + return Either.right(new Tuple2<>(t._1, t._2)); } - public static Either> fact_term(String s) { - if (s.length() > 0 && s.charAt(0) == '$') { - return Either.left(new Error(s, "variables are not allowed in facts")); - } + Either> res3 = integer(s); + if (res3.isRight()) { + Tuple2 t = res3.get(); + return Either.right(new Tuple2<>(t._1, t._2)); + } - Either> res2 = string(s); - if (res2.isRight()) { - Tuple2 t = res2.get(); - return Either.right(new Tuple2<>(t._1, t._2)); - } + Either> res8 = bytes(s); + if (res8.isRight()) { + Tuple2 t = res8.get(); + return Either.right(new Tuple2<>(t._1, t._2)); + } - Either> res7 = set(s); - if (res7.isRight()) { - Tuple2 t = res7.get(); - return Either.right(new Tuple2<>(t._1, t._2)); - } + return Either.left(new Error(s, "unrecognized value")); + } - Either> res6 = bool(s); - if (res6.isRight()) { - Tuple2 t = res6.get(); - return Either.right(new Tuple2<>(t._1, t._2)); - } + public static Either> factTerm(String s) { + if (s.length() > 0 && s.charAt(0) == '$') { + return Either.left(new Error(s, "variables are not allowed in facts")); + } - Either> res4 = date(s); - if (res4.isRight()) { - Tuple2 t = res4.get(); - return Either.right(new Tuple2<>(t._1, t._2)); - } + Either> res2 = string(s); + if (res2.isRight()) { + Tuple2 t = res2.get(); + return Either.right(new Tuple2<>(t._1, t._2)); + } - Either> res3 = integer(s); - if (res3.isRight()) { - Tuple2 t = res3.get(); - return Either.right(new Tuple2<>(t._1, t._2)); - } + Either> res7 = set(s); + if (res7.isRight()) { + Tuple2 t = res7.get(); + return Either.right(new Tuple2<>(t._1, t._2)); + } - Either> res8 = bytes(s); - if (res8.isRight()) { - Tuple2 t = res8.get(); - return Either.right(new Tuple2<>(t._1, t._2)); - } + Either> res6 = bool(s); + if (res6.isRight()) { + Tuple2 t = res6.get(); + return Either.right(new Tuple2<>(t._1, t._2)); + } - return Either.left(new Error(s, "unrecognized value")); + Either> res4 = date(s); + if (res4.isRight()) { + Tuple2 t = res4.get(); + return Either.right(new Tuple2<>(t._1, t._2)); } - public static Either> string(String s) { - if (s.charAt(0) != '"') { - return Either.left(new Error(s, "not a string")); - } + Either> res3 = integer(s); + if (res3.isRight()) { + Tuple2 t = res3.get(); + return Either.right(new Tuple2<>(t._1, t._2)); + } - int index = s.length(); - for (int i = 1; i < s.length(); i++) { - char c = s.charAt(i); + Either> res8 = bytes(s); + if (res8.isRight()) { + Tuple2 t = res8.get(); + return Either.right(new Tuple2<>(t._1, t._2)); + } - if (c == '\\' && s.charAt(i + 1) == '"') { - i += 1; - continue; - } + return Either.left(new Error(s, "unrecognized value")); + } - if (c == '"') { - index = i - 1; - break; - } - } + public static Either> string(String s) { + if (s.charAt(0) != '"') { + return Either.left(new Error(s, "not a string")); + } - if (index == s.length()) { - return Either.left(new Error(s, "end of string not found")); - } + int index = s.length(); + for (int i = 1; i < s.length(); i++) { + char c = s.charAt(i); - if (s.charAt(index + 1) != '"') { - return Either.left(new Error(s, "ending double quote not found")); - } + if (c == '\\' && s.charAt(i + 1) == '"') { + i += 1; + continue; + } - String string = s.substring(1, index + 1); - String remaining = s.substring(index + 2); - - return Either.right(new Tuple2(remaining, (Term.Str) Utils.string(string))); + if (c == '"') { + index = i - 1; + break; + } } - public static Either> integer(String s) { - int index = 0; - if (s.charAt(0) == '-') { - index += 1; - } - - int index2 = s.length(); - for (int i = index; i < s.length(); i++) { - char c = s.charAt(i); + if (index == s.length()) { + return Either.left(new Error(s, "end of string not found")); + } - if (!Character.isDigit(c)) { - index2 = i; - break; - } - } + if (s.charAt(index + 1) != '"') { + return Either.left(new Error(s, "ending double quote not found")); + } - if (index2 == 0) { - return Either.left(new Error(s, "not an integer")); - } + String string = s.substring(1, index + 1); + String remaining = s.substring(index + 2); - long i = Long.parseLong(s.substring(0, index2)); - String remaining = s.substring(index2); + return Either.right(new Tuple2(remaining, (Term.Str) Utils.string(string))); + } - return Either.right(new Tuple2(remaining, (Term.Integer) Utils.integer(i))); + public static Either> integer(String s) { + int index = 0; + if (s.charAt(0) == '-') { + index += 1; } - public static Either> date(String s) { - Tuple2 t = take_while(s, (c) -> c != ' ' && c != ',' && c != ')' && c != ']'); + int index2 = s.length(); + for (int i = index; i < s.length(); i++) { + char c = s.charAt(i); - try { - OffsetDateTime d = OffsetDateTime.parse(t._1); - String remaining = t._2; - return Either.right(new Tuple2(remaining, new Term.Date(d.toEpochSecond()))); - } catch (DateTimeParseException e) { - return Either.left(new Error(s, "not a date")); + if (!Character.isDigit(c)) { + index2 = i; + break; + } + } - } + if (index2 == 0) { + return Either.left(new Error(s, "not an integer")); } - public static Either> variable(String s) { - if (s.charAt(0) != '$') { - return Either.left(new Error(s, "not a variable")); - } + long i = Long.parseLong(s.substring(0, index2)); + String remaining = s.substring(index2); - Tuple2 t = take_while(s.substring(1), (c) -> Character.isAlphabetic(c) || Character.isDigit(c) || c == '_'); + return Either.right( + new Tuple2(remaining, (Term.Integer) Utils.integer(i))); + } - return Either.right(new Tuple2(t._2, (Term.Variable) Utils.var(t._1))); - } + public static Either> date(String s) { + Tuple2 t = takewhile(s, (c) -> c != ' ' && c != ',' && c != ')' && c != ']'); - public static Either> bool(String s) { - boolean b; - if (s.startsWith("true")) { - b = true; - s = s.substring(4); - } else if (s.startsWith("false")) { - b = false; - s = s.substring(5); - } else { - return Either.left(new Error(s, "not a boolean")); - } - - return Either.right(new Tuple2<>(s, new Term.Bool(b))); + try { + OffsetDateTime d = OffsetDateTime.parse(t._1); + String remaining = t._2; + return Either.right( + new Tuple2(remaining, new Term.Date(d.toEpochSecond()))); + } catch (DateTimeParseException e) { + return Either.left(new Error(s, "not a date")); } + } - public static Either> set(String s) { - if (s.length() == 0 || s.charAt(0) != '[') { - return Either.left(new Error(s, "not a set")); - } + public static Either> variable(String s) { + if (s.charAt(0) != '$') { + return Either.left(new Error(s, "not a variable")); + } - s = s.substring(1); + Tuple2 t = + takewhile( + s.substring(1), (c) -> Character.isAlphabetic(c) || Character.isDigit(c) || c == '_'); + + return Either.right(new Tuple2(t._2, (Term.Variable) Utils.var(t._1))); + } + + public static Either> bool(String s) { + boolean b; + if (s.startsWith("true")) { + b = true; + s = s.substring(4); + } else if (s.startsWith("false")) { + b = false; + s = s.substring(5); + } else { + return Either.left(new Error(s, "not a boolean")); + } - HashSet terms = new HashSet(); - while (true) { + return Either.right(new Tuple2<>(s, new Term.Bool(b))); + } - s = space(s); + public static Either> set(String s) { + if (s.length() == 0 || s.charAt(0) != '[') { + return Either.left(new Error(s, "not a set")); + } - Either> res = fact_term(s); - if (res.isLeft()) { - break; - } + s = s.substring(1); - Tuple2 t = res.get(); + HashSet terms = new HashSet(); + while (true) { - if (t._2 instanceof Term.Variable) { - return Either.left(new Error(s, "sets cannot contain variables")); - } + s = space(s); - s = t._1; - terms.add(t._2); + Either> res = factTerm(s); + if (res.isLeft()) { + break; + } - s = space(s); + Tuple2 t = res.get(); - if (s.length() == 0 || s.charAt(0) != ',') { - break; - } else { - s = s.substring(1); - } - } + if (t._2 instanceof Term.Variable) { + return Either.left(new Error(s, "sets cannot contain variables")); + } - s = space(s); - if (0 == s.length() || s.charAt(0) != ']') { - return Either.left(new Error(s, "closing square bracket not found")); - } + s = t._1; + terms.add(t._2); - String remaining = s.substring(1); + s = space(s); - return Either.right(new Tuple2<>(remaining, new Term.Set(terms))); + if (s.length() == 0 || s.charAt(0) != ',') { + break; + } else { + s = s.substring(1); + } } - public static Either> bytes(String s) { - if (!s.startsWith("hex:")) { - return Either.left(new Error(s, "not a bytes array")); - } - s = s.substring(4); - Tuple2 t = hex(s); - return Either.right(new Tuple2<>(t._1, new Term.Bytes(t._2))); + s = space(s); + if (0 == s.length() || s.charAt(0) != ']') { + return Either.left(new Error(s, "closing square bracket not found")); } - public static Tuple2 hex(String s) { - int index = 0; - for (int i = 0; i < s.length(); i++) { - char c = s.charAt(i); - if("0123456789ABCDEFabcdef".indexOf(c) == -1) { - break; - } - - index += 1; - } + String remaining = s.substring(1); - String hex = s.substring(0, index); - byte[] bytes = Utils.hexStringToByteArray(hex); - s = s.substring(index); - return new Tuple2<>(s,bytes); + return Either.right(new Tuple2<>(remaining, new Term.Set(terms))); + } + public static Either> bytes(String s) { + if (!s.startsWith("hex:")) { + return Either.left(new Error(s, "not a bytes array")); } - - public static Either> expression(String s) { - return ExpressionParser.parse(s); + s = s.substring(4); + Tuple2 t = hex(s); + return Either.right(new Tuple2<>(t._1, new Term.Bytes(t._2))); + } + + public static Tuple2 hex(String s) { + int index = 0; + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + if ("0123456789ABCDEFabcdef".indexOf(c) == -1) { + break; + } + + index += 1; } - public static String space(String s) { - int index = 0; - for (int i = 0; i < s.length(); i++) { - char c = s.charAt(i); - - if (c != ' ' && c != '\t' && c != '\r' && c != '\n') { - break; - } - index += 1; - } - - return s.substring(index); + String hex = s.substring(0, index); + byte[] bytes = Utils.hexStringToByteArray(hex); + s = s.substring(index); + return new Tuple2<>(s, bytes); + } + + public static Either> expression(String s) { + return ExpressionParser.parse(s); + } + + public static String space(String s) { + int index = 0; + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + + if (c != ' ' && c != '\t' && c != '\r' && c != '\n') { + break; + } + index += 1; } - public static Tuple2 take_while(String s, Function f) { - int index = s.length(); - for (int i = 0; i < s.length(); i++) { - Character c = s.charAt(i); + return s.substring(index); + } - if (!f.apply(c)) { - index = i; - break; - } - } + public static Tuple2 takewhile(String s, Function f) { + int index = s.length(); + for (int i = 0; i < s.length(); i++) { + Character c = s.charAt(i); - return new Tuple2<>(s.substring(0, index), s.substring(index)); - } - - public static String removeCommentsAndWhitespaces(String s) { - s = removeComments(s); - s = s.replace("\n", "").replace("\\\"", "\"").strip(); - return s; - } - - public static String removeComments(String str) { - StringBuilder result = new StringBuilder(); - String remaining = str; - - while (!remaining.isEmpty()) { - remaining = space(remaining); // Skip leading whitespace - if (remaining.startsWith("/*")) { - // Find the end of the multiline comment - remaining = remaining.substring(2); // Skip "/*" - String finalRemaining = remaining; - Tuple2 split = take_while(remaining, c -> !finalRemaining.startsWith("*/")); - remaining = split._2.length() > 2 ? split._2.substring(2) : ""; // Skip "*/" - } else if (remaining.startsWith("//")) { - // Find the end of the single-line comment - remaining = remaining.substring(2); // Skip "//" - Tuple2 split = take_while(remaining, c -> c != '\n' && c != '\r'); - remaining = split._2; - if (!remaining.isEmpty()) { - result.append(remaining.charAt(0)); // Preserve line break - remaining = remaining.substring(1); - } - } else { - // Take non-comment text until the next comment or end of string - String finalRemaining = remaining; - Tuple2 split = take_while(remaining, c -> !finalRemaining.startsWith("/*") && !finalRemaining.startsWith("//")); - result.append(split._1); - remaining = split._2; - } - } + if (!f.apply(c)) { + index = i; + break; + } + } - return result.toString(); + return new Tuple2<>(s.substring(0, index), s.substring(index)); + } + + public static String removeCommentsAndWhitespaces(String s) { + s = removeComments(s); + s = s.replace("\n", "").replace("\\\"", "\"").strip(); + return s; + } + + public static String removeComments(String str) { + StringBuilder result = new StringBuilder(); + String remaining = str; + + while (!remaining.isEmpty()) { + remaining = space(remaining); // Skip leading whitespace + if (remaining.startsWith("/*")) { + // Find the end of the multiline comment + remaining = remaining.substring(2); // Skip "/*" + String finalRemaining = remaining; + Tuple2 split = takewhile(remaining, c -> !finalRemaining.startsWith("*/")); + remaining = split._2.length() > 2 ? split._2.substring(2) : ""; // Skip "*/" + } else if (remaining.startsWith("//")) { + // Find the end of the single-line comment + remaining = remaining.substring(2); // Skip "//" + Tuple2 split = takewhile(remaining, c -> c != '\n' && c != '\r'); + remaining = split._2; + if (!remaining.isEmpty()) { + result.append(remaining.charAt(0)); // Preserve line break + remaining = remaining.substring(1); + } + } else { + // Take non-comment text until the next comment or end of string + String finalRemaining = remaining; + Tuple2 split = + takewhile( + remaining, + c -> !finalRemaining.startsWith("/*") && !finalRemaining.startsWith("//")); + result.append(split._1); + remaining = split._2; + } } + + return result.toString(); + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/format/ExternalSignature.java b/src/main/java/org/biscuitsec/biscuit/token/format/ExternalSignature.java index 1248d3aa..ac8fb085 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/format/ExternalSignature.java +++ b/src/main/java/org/biscuitsec/biscuit/token/format/ExternalSignature.java @@ -3,11 +3,19 @@ import org.biscuitsec.biscuit.crypto.PublicKey; public class ExternalSignature { - public PublicKey key; - public byte[] signature; + private final PublicKey key; + private final byte[] signature; - public ExternalSignature(PublicKey key, byte[] signature) { - this.key = key; - this.signature = signature; - } + public ExternalSignature(PublicKey key, byte[] signature) { + this.key = key; + this.signature = signature; + } + + public PublicKey getKey() { + return key; + } + + public byte[] getSignature() { + return signature; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/format/Proof.java b/src/main/java/org/biscuitsec/biscuit/token/format/Proof.java index b4a0ef82..3f53b62e 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/format/Proof.java +++ b/src/main/java/org/biscuitsec/biscuit/token/format/Proof.java @@ -1,24 +1,86 @@ package org.biscuitsec.biscuit.token.format; -import org.biscuitsec.biscuit.crypto.KeyPair; import io.vavr.control.Option; +import org.biscuitsec.biscuit.crypto.KeyPair; + +/** Sum type for Proof NextSecret or FinalSignature. */ +interface Proof { + /** + * Test if the proof is sealed. + * + * @return true if sealed and is FinalSignature + */ + boolean isSealed(); + + /** + * Get the KeyPair of the proof if the proof is not sealed. + * + * @return secret keypair if not sealed + */ + KeyPair secretKey(); + + /** + * Get the signature in case of sealed proof. + * + * @return the signature if sealed or none + */ + Option getSignature(); + + /** NextSecret with a keypair. */ + final class NextSecret implements Proof { + /** the secret keypair for the block. */ + private final KeyPair secretKey; + + /** + * Create a new NextSecret. + * + * @param secretKey the associated keypair + */ + NextSecret(final KeyPair secretKey) { + this.secretKey = secretKey; + } + + @Override + public KeyPair secretKey() { + return this.secretKey; + } + + @Override + public boolean isSealed() { + return false; + } + + @Override + public Option getSignature() { + return Option.none(); + } + } -public class Proof { - public Option secretKey; - public Option signature; + final class FinalSignature implements Proof { + /** final signature. */ + private final byte[] signature; + + FinalSignature(final byte[] signature) { + this.signature = signature; + } + + public byte[] signature() { + return this.signature; + } - public Proof(Option secretKey, Option signature) { - this.secretKey = secretKey; - this.signature = signature; + @Override + public KeyPair secretKey() { + throw new RuntimeException("Sealed Block no keypair available"); } - public Proof(KeyPair secretKey) { - this.secretKey = Option.some(secretKey); - this.signature = Option.none(); + @Override + public boolean isSealed() { + return true; } - public Proof(byte[] signature) { - this.secretKey = Option.none(); - this.signature = Option.some(signature); + @Override + public Option getSignature() { + return Option.some(this.signature); } + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/format/SerializedBiscuit.java b/src/main/java/org/biscuitsec/biscuit/token/format/SerializedBiscuit.java index e0974260..78d2c502 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/format/SerializedBiscuit.java +++ b/src/main/java/org/biscuitsec/biscuit/token/format/SerializedBiscuit.java @@ -1,483 +1,531 @@ package org.biscuitsec.biscuit.token.format; +import static io.vavr.API.Left; +import static io.vavr.API.Right; + import biscuit.format.schema.Schema; -import io.vavr.Tuple2; -import org.biscuitsec.biscuit.crypto.BlockSignatureBuffer; -import org.biscuitsec.biscuit.crypto.KeyDelegate; -import org.biscuitsec.biscuit.crypto.KeyPair; -import org.biscuitsec.biscuit.crypto.PublicKey; -import org.biscuitsec.biscuit.datalog.SymbolTable; -import org.biscuitsec.biscuit.error.Error; -import org.biscuitsec.biscuit.token.Block; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; +import io.vavr.Tuple2; import io.vavr.control.Either; import io.vavr.control.Option; - import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.security.*; -import java.util.*; - -import static io.vavr.API.Left; -import static io.vavr.API.Right; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.security.Signature; +import java.security.SignatureException; +import java.util.ArrayList; +import java.util.List; +import org.biscuitsec.biscuit.crypto.BlockSignatureBuffer; +import org.biscuitsec.biscuit.crypto.KeyDelegate; +import org.biscuitsec.biscuit.crypto.KeyPair; +import org.biscuitsec.biscuit.crypto.PublicKey; +import org.biscuitsec.biscuit.datalog.SymbolTable; +import org.biscuitsec.biscuit.error.Error; +import org.biscuitsec.biscuit.token.Block; -/** - * Intermediate representation of a token before full serialization - */ -public class SerializedBiscuit { - public SignedBlock authority; - public List blocks; - public Proof proof; - public Option root_key_id; - - public static int MIN_SCHEMA_VERSION = 3; - public static int MAX_SCHEMA_VERSION = 5; - - /** - * Deserializes a SerializedBiscuit from a byte array - * - * @param slice - * @return - */ - static public SerializedBiscuit from_bytes(byte[] slice, org.biscuitsec.biscuit.crypto.PublicKey root) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - try { - Schema.Biscuit data = Schema.Biscuit.parseFrom(slice); - - return from_bytes_inner(data, root); - } catch (InvalidProtocolBufferException e) { - throw new Error.FormatError.DeserializationError(e.toString()); - } +/** Intermediate representation of a token before full serialization */ +public final class SerializedBiscuit { + private final SignedBlock authority; + private final List blocks; + private Proof proof; + private Option rootKeyId; + + public static final int MIN_SCHEMA_VERSION = 3; + public static final int MAX_SCHEMA_VERSION = 5; + + /** + * Deserializes a SerializedBiscuit from a byte array + * + * @param slice + * @return + */ + public static SerializedBiscuit fromBytes( + byte[] slice, org.biscuitsec.biscuit.crypto.PublicKey root) + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + try { + Schema.Biscuit data = Schema.Biscuit.parseFrom(slice); + + return fromBytesInner(data, root); + } catch (InvalidProtocolBufferException e) { + throw new Error.FormatError.DeserializationError(e.toString()); } - - /** - * Deserializes a SerializedBiscuit from a byte array - * - * @param slice - * @return - */ - static public SerializedBiscuit from_bytes(byte[] slice, KeyDelegate delegate) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - try { - Schema.Biscuit data = Schema.Biscuit.parseFrom(slice); - - Option root_key_id = Option.none(); - if (data.hasRootKeyId()) { - root_key_id = Option.some(data.getRootKeyId()); - } - - Option root = delegate.root_key(root_key_id); - if (root.isEmpty()) { - throw new InvalidKeyException("unknown root key id"); - } - - return from_bytes_inner(data, root.get()); - } catch (InvalidProtocolBufferException e) { - throw new Error.FormatError.DeserializationError(e.toString()); - } + } + + /** + * Deserializes a SerializedBiscuit from a byte array + * + * @param slice + * @return + */ + public static SerializedBiscuit fromBytes(byte[] slice, KeyDelegate delegate) + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + try { + Schema.Biscuit data = Schema.Biscuit.parseFrom(slice); + + Option rootKeyId = Option.none(); + if (data.hasRootKeyId()) { + rootKeyId = Option.some(data.getRootKeyId()); + } + + Option root = delegate.getRootKey(rootKeyId); + if (root.isEmpty()) { + throw new InvalidKeyException("unknown root key id"); + } + + return fromBytesInner(data, root.get()); + } catch (InvalidProtocolBufferException e) { + throw new Error.FormatError.DeserializationError(e.toString()); } - - static SerializedBiscuit from_bytes_inner(Schema.Biscuit data, org.biscuitsec.biscuit.crypto.PublicKey root) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - SerializedBiscuit b = SerializedBiscuit.deserialize(data); - if (data.hasRootKeyId()) { - b.root_key_id = Option.some(data.getRootKeyId()); - } - - Either res = b.verify(root); - if (res.isLeft()) { - throw res.getLeft(); - } else { - return b; - } - + } + + static SerializedBiscuit fromBytesInner( + Schema.Biscuit data, org.biscuitsec.biscuit.crypto.PublicKey root) + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + SerializedBiscuit b = SerializedBiscuit.deserialize(data); + if (data.hasRootKeyId()) { + b.rootKeyId = Option.some(data.getRootKeyId()); } - /** - * Warning: this deserializes without verifying the signature - * - * @param slice - * @return SerializedBiscuit - * @throws Error.FormatError.DeserializationError - */ - static public SerializedBiscuit unsafe_deserialize(byte[] slice) throws Error.FormatError.DeserializationError { - try { - Schema.Biscuit data = Schema.Biscuit.parseFrom(slice); - return SerializedBiscuit.deserialize(data); - } catch (InvalidProtocolBufferException e) { - throw new Error.FormatError.DeserializationError(e.toString()); - } + Either res = b.verify(root); + if (res.isLeft()) { + throw res.getLeft(); + } else { + return b; } - - /** - * Warning: this deserializes without verifying the signature - * - * @param data - * @return SerializedBiscuit - * @throws Error.FormatError.DeserializationError - */ - static private SerializedBiscuit deserialize(Schema.Biscuit data) throws Error.FormatError.DeserializationError { - if(data.getAuthority().hasExternalSignature()) { - throw new Error.FormatError.DeserializationError("the authority block must not contain an external signature"); - } - - SignedBlock authority = new SignedBlock( - data.getAuthority().getBlock().toByteArray(), - org.biscuitsec.biscuit.crypto.PublicKey.deserialize(data.getAuthority().getNextKey()), - data.getAuthority().getSignature().toByteArray(), - Option.none() - ); - - ArrayList blocks = new ArrayList<>(); - for (Schema.SignedBlock block : data.getBlocksList()) { - Option external = Option.none(); - if(block.hasExternalSignature() && block.getExternalSignature().hasPublicKey() - && block.getExternalSignature().hasSignature()) { - Schema.ExternalSignature ex = block.getExternalSignature(); - external = Option.some(new ExternalSignature( - org.biscuitsec.biscuit.crypto.PublicKey.deserialize(ex.getPublicKey()), - ex.getSignature().toByteArray())); - - } - blocks.add(new SignedBlock( - block.getBlock().toByteArray(), - org.biscuitsec.biscuit.crypto.PublicKey.deserialize(block.getNextKey()), - block.getSignature().toByteArray(), - external - )); - } - - Option secretKey = Option.none(); - if (data.getProof().hasNextSecret()) { - secretKey = Option.some(KeyPair.generate(authority.key.algorithm, data.getProof().getNextSecret().toByteArray())); - } - - Option signature = Option.none(); - if (data.getProof().hasFinalSignature()) { - signature = Option.some(data.getProof().getFinalSignature().toByteArray()); - } - - if (secretKey.isEmpty() && signature.isEmpty()) { - throw new Error.FormatError.DeserializationError("empty proof"); - } - Proof proof = new Proof(secretKey, signature); - - Option rootKeyId = data.hasRootKeyId() ? Option.some(data.getRootKeyId()) : Option.none(); - - return new SerializedBiscuit(authority, blocks, proof, rootKeyId); + } + + /** + * Warning: this deserializes without verifying the signature + * + * @param slice + * @return SerializedBiscuit + * @throws Error.FormatError.DeserializationError + */ + public static SerializedBiscuit deserializeUnsafe(byte[] slice) + throws Error.FormatError.DeserializationError { + try { + Schema.Biscuit data = Schema.Biscuit.parseFrom(slice); + return SerializedBiscuit.deserialize(data); + } catch (InvalidProtocolBufferException e) { + throw new Error.FormatError.DeserializationError(e.toString()); + } + } + + /** + * Warning: this deserializes without verifying the signature + * + * @param data + * @return SerializedBiscuit + * @throws Error.FormatError.DeserializationError + */ + private static SerializedBiscuit deserialize(Schema.Biscuit data) + throws Error.FormatError.DeserializationError { + if (data.getAuthority().hasExternalSignature()) { + throw new Error.FormatError.DeserializationError( + "the authority block must not contain an external signature"); } + SignedBlock authority = + new SignedBlock( + data.getAuthority().getBlock().toByteArray(), + org.biscuitsec.biscuit.crypto.PublicKey.deserialize(data.getAuthority().getNextKey()), + data.getAuthority().getSignature().toByteArray(), + Option.none()); + + ArrayList blocks = new ArrayList<>(); + for (Schema.SignedBlock block : data.getBlocksList()) { + Option external = Option.none(); + if (block.hasExternalSignature() + && block.getExternalSignature().hasPublicKey() + && block.getExternalSignature().hasSignature()) { + Schema.ExternalSignature ex = block.getExternalSignature(); + external = + Option.some( + new ExternalSignature( + org.biscuitsec.biscuit.crypto.PublicKey.deserialize(ex.getPublicKey()), + ex.getSignature().toByteArray())); + } + blocks.add( + new SignedBlock( + block.getBlock().toByteArray(), + org.biscuitsec.biscuit.crypto.PublicKey.deserialize(block.getNextKey()), + block.getSignature().toByteArray(), + external)); + } - /** - * Serializes a SerializedBiscuit to a byte array - * - * @return - */ - public byte[] serialize() throws Error.FormatError.SerializationError { - Schema.Biscuit.Builder biscuitBuilder = Schema.Biscuit.newBuilder(); - Schema.SignedBlock.Builder authorityBuilder = Schema.SignedBlock.newBuilder(); - { - SignedBlock block = this.authority; - authorityBuilder.setBlock(ByteString.copyFrom(block.block)); - authorityBuilder.setNextKey(block.key.serialize()); - authorityBuilder.setSignature(ByteString.copyFrom(block.signature)); - } - biscuitBuilder.setAuthority(authorityBuilder.build()); - - for (SignedBlock block : this.blocks) { - Schema.SignedBlock.Builder blockBuilder = Schema.SignedBlock.newBuilder(); - blockBuilder.setBlock(ByteString.copyFrom(block.block)); - blockBuilder.setNextKey(block.key.serialize()); - blockBuilder.setSignature(ByteString.copyFrom(block.signature)); - - if (block.externalSignature.isDefined()) { - ExternalSignature externalSignature = block.externalSignature.get(); - Schema.ExternalSignature.Builder externalSignatureBuilder = Schema.ExternalSignature.newBuilder(); - externalSignatureBuilder.setPublicKey(externalSignature.key.serialize()); - externalSignatureBuilder.setSignature(ByteString.copyFrom(externalSignature.signature)); - blockBuilder.setExternalSignature(externalSignatureBuilder.build()); - } - - biscuitBuilder.addBlocks(blockBuilder.build()); - } - - Schema.Proof.Builder proofBuilder = Schema.Proof.newBuilder(); - if (!this.proof.secretKey.isEmpty()) { - proofBuilder.setNextSecret(ByteString.copyFrom(this.proof.secretKey.get().toBytes())); - } else { - proofBuilder.setFinalSignature(ByteString.copyFrom(this.proof.signature.get())); - } - - biscuitBuilder.setProof(proofBuilder.build()); - if (!this.root_key_id.isEmpty()) { - biscuitBuilder.setRootKeyId(this.root_key_id.get()); - } + if (!(data.getProof().hasNextSecret() ^ data.getProof().hasFinalSignature())) { + throw new Error.FormatError.DeserializationError("empty proof"); + } - Schema.Biscuit biscuit = biscuitBuilder.build(); + final Proof proof = + data.getProof().hasFinalSignature() + ? new Proof.FinalSignature(data.getProof().getFinalSignature().toByteArray()) + : new Proof.NextSecret( + KeyPair.generate( + authority.getKey().getAlgorithm(), + data.getProof().getNextSecret().toByteArray())); + + Option rootKeyId = + data.hasRootKeyId() ? Option.some(data.getRootKeyId()) : Option.none(); + + return new SerializedBiscuit(authority, blocks, proof, rootKeyId); + } + + /** + * Serializes a SerializedBiscuit to a byte array + * + * @return + */ + public byte[] serialize() throws Error.FormatError.SerializationError { + Schema.SignedBlock.Builder authorityBuilder = Schema.SignedBlock.newBuilder(); + SignedBlock authorityBlock = this.authority; + authorityBuilder.setBlock(ByteString.copyFrom(authorityBlock.getBlock())); + authorityBuilder.setNextKey(authorityBlock.getKey().serialize()); + authorityBuilder.setSignature(ByteString.copyFrom(authorityBlock.getSignature())); + Schema.Biscuit.Builder biscuitBuilder = Schema.Biscuit.newBuilder(); + biscuitBuilder.setAuthority(authorityBuilder.build()); + + for (SignedBlock b : this.blocks) { + Schema.SignedBlock.Builder blockBuilder = Schema.SignedBlock.newBuilder(); + blockBuilder.setBlock(ByteString.copyFrom(b.getBlock())); + blockBuilder.setNextKey(b.getKey().serialize()); + blockBuilder.setSignature(ByteString.copyFrom(b.getSignature())); + + if (b.getExternalSignature().isDefined()) { + ExternalSignature externalSignature = b.getExternalSignature().get(); + Schema.ExternalSignature.Builder externalSignatureBuilder = + Schema.ExternalSignature.newBuilder(); + externalSignatureBuilder.setPublicKey(externalSignature.getKey().serialize()); + externalSignatureBuilder.setSignature( + ByteString.copyFrom(externalSignature.getSignature())); + blockBuilder.setExternalSignature(externalSignatureBuilder.build()); + } + + biscuitBuilder.addBlocks(blockBuilder.build()); + } - try { - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - biscuit.writeTo(stream); - return stream.toByteArray(); - } catch (IOException e) { - throw new Error.FormatError.SerializationError(e.toString()); - } + Schema.Proof.Builder proofBuilder = Schema.Proof.newBuilder(); + if (this.proof.isSealed()) { + Proof.FinalSignature finalSignature = (Proof.FinalSignature) this.proof; + proofBuilder.setFinalSignature(ByteString.copyFrom(finalSignature.signature())); + } else { + Proof.NextSecret nextSecret = (Proof.NextSecret) this.proof; + proofBuilder.setNextSecret(ByteString.copyFrom(nextSecret.secretKey().toBytes())); + } + biscuitBuilder.setProof(proofBuilder.build()); + if (!this.rootKeyId.isEmpty()) { + biscuitBuilder.setRootKeyId(this.rootKeyId.get()); } - static public Either make(final org.biscuitsec.biscuit.crypto.KeyPair root, - final Block authority, final org.biscuitsec.biscuit.crypto.KeyPair next) { + Schema.Biscuit biscuit = biscuitBuilder.build(); - return make(root, Option.none(), authority, next); + try { + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + biscuit.writeTo(stream); + return stream.toByteArray(); + } catch (IOException e) { + throw new Error.FormatError.SerializationError(e.toString()); } - - static public Either make(final org.biscuitsec.biscuit.crypto.Signer rootSigner, final Option root_key_id, - final Block authority, final org.biscuitsec.biscuit.crypto.KeyPair next) { - Schema.Block b = authority.serialize(); - try { - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - b.writeTo(stream); - byte[] block = stream.toByteArray(); - org.biscuitsec.biscuit.crypto.PublicKey next_key = next.public_key(); - byte[] payload = BlockSignatureBuffer.getBufferSignature(next_key, block); - byte[] signature = rootSigner.sign(payload); - SignedBlock signedBlock = new SignedBlock(block, next_key, signature, Option.none()); - Proof proof = new Proof(next); - - return Right(new SerializedBiscuit(signedBlock, new ArrayList<>(), proof, root_key_id)); - } catch (IOException | NoSuchAlgorithmException | SignatureException | InvalidKeyException e) { - return Left(new Error.FormatError.SerializationError(e.toString())); - } + } + + public static Either make( + final org.biscuitsec.biscuit.crypto.KeyPair root, + final Block authority, + final org.biscuitsec.biscuit.crypto.KeyPair next) { + + return make(root, Option.none(), authority, next); + } + + public static Either make( + final org.biscuitsec.biscuit.crypto.Signer rootSigner, + final Option rootKeyId, + final Block authority, + final org.biscuitsec.biscuit.crypto.KeyPair next) { + Schema.Block b = authority.serialize(); + try { + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + b.writeTo(stream); + byte[] block = stream.toByteArray(); + org.biscuitsec.biscuit.crypto.PublicKey nextKey = next.getPublicKey(); + byte[] payload = BlockSignatureBuffer.getBufferSignature(nextKey, block); + byte[] signature = rootSigner.sign(payload); + SignedBlock signedBlock = new SignedBlock(block, nextKey, signature, Option.none()); + Proof proof = new Proof.NextSecret(next); + + return Right(new SerializedBiscuit(signedBlock, new ArrayList<>(), proof, rootKeyId)); + } catch (IOException | NoSuchAlgorithmException | SignatureException | InvalidKeyException e) { + return Left(new Error.FormatError.SerializationError(e.toString())); + } + } + + public Either append( + final org.biscuitsec.biscuit.crypto.KeyPair next, + final Block newBlock, + Option externalSignature) { + if (this.proof.isSealed()) { + return Left(new Error.FormatError.SerializationError("the token is sealed")); } - public Either append(final org.biscuitsec.biscuit.crypto.KeyPair next, - final Block newBlock, Option externalSignature) { - if (this.proof.secretKey.isEmpty()) { - return Left(new Error.FormatError.SerializationError("the token is sealed")); - } + Schema.Block b = newBlock.serialize(); + try { + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + b.writeTo(stream); - Schema.Block b = newBlock.serialize(); - try { - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - b.writeTo(stream); + byte[] block = stream.toByteArray(); + KeyPair secretKey = this.proof.secretKey(); + org.biscuitsec.biscuit.crypto.PublicKey nextKey = next.getPublicKey(); - byte[] block = stream.toByteArray(); - KeyPair secretKey = this.proof.secretKey.get(); - org.biscuitsec.biscuit.crypto.PublicKey next_key = next.public_key(); + byte[] payload = + BlockSignatureBuffer.getBufferSignature( + nextKey, block, externalSignature.toJavaOptional()); + byte[] signature = this.proof.secretKey().sign(payload); - byte[] payload = BlockSignatureBuffer.getBufferSignature(next_key, block, externalSignature.toJavaOptional()); - byte[] signature = this.proof.secretKey.get().sign(payload); + SignedBlock signedBlock = new SignedBlock(block, nextKey, signature, externalSignature); - SignedBlock signedBlock = new SignedBlock(block, next_key, signature, externalSignature); + ArrayList blocks = new ArrayList<>(); + for (SignedBlock bl : this.blocks) { + blocks.add(bl); + } + blocks.add(signedBlock); - ArrayList blocks = new ArrayList<>(); - for (SignedBlock bl : this.blocks) { - blocks.add(bl); - } - blocks.add(signedBlock); + Proof proof = new Proof.NextSecret(next); - Proof proof = new Proof(next); + return Right(new SerializedBiscuit(this.authority, blocks, proof, rootKeyId)); + } catch (IOException | NoSuchAlgorithmException | SignatureException | InvalidKeyException e) { + return Left(new Error.FormatError.SerializationError(e.toString())); + } + } + + public Either verify(org.biscuitsec.biscuit.crypto.PublicKey root) + throws NoSuchAlgorithmException, InvalidKeyException, SignatureException { + org.biscuitsec.biscuit.crypto.PublicKey currentKey = root; + Either res = + verifyBlockSignature(this.authority, currentKey); + if (res.isRight()) { + currentKey = res.get(); + } else { + return Left(res.getLeft()); + } - return Right(new SerializedBiscuit(this.authority, blocks, proof, root_key_id)); - } catch (IOException | NoSuchAlgorithmException | SignatureException | InvalidKeyException e) { - return Left(new Error.FormatError.SerializationError(e.toString())); - } + for (SignedBlock b : this.blocks) { + res = verifyBlockSignature(b, currentKey); + if (res.isRight()) { + currentKey = res.get(); + } else { + return Left(res.getLeft()); + } } - public Either verify(org.biscuitsec.biscuit.crypto.PublicKey root) throws NoSuchAlgorithmException, InvalidKeyException, SignatureException { - org.biscuitsec.biscuit.crypto.PublicKey current_key = root; - { - Either res = verifyBlockSignature(this.authority, current_key); - if(res.isRight()) { - current_key = res.get(); - } else { - return Left(res.getLeft()); - } - } + // System.out.println("signatures verified, checking proof"); - for (SignedBlock b : this.blocks) { - Either res = verifyBlockSignature(b, current_key); - if(res.isRight()) { - current_key = res.get(); - } else { - return Left(res.getLeft()); - } - } + if (!this.proof.isSealed()) { + // System.out.println("checking secret key"); + // System.out.println("current key: " + currentKey.toHex()); + // System.out.println("key from proof: " + this.proof.secretKey.get().public_key().toHex()); + if (this.proof.secretKey().getPublicKey().equals(currentKey)) { + // System.out.println("public keys are equal"); - //System.out.println("signatures verified, checking proof"); + return Right(null); + } else { + // System.out.println("public keys are not equal"); - if (!this.proof.secretKey.isEmpty()) { - //System.out.println("checking secret key"); - //System.out.println("current key: "+current_key.toHex()); - //System.out.println("key from proof: "+this.proof.secretKey.get().public_key().toHex()); - if (this.proof.secretKey.get().public_key().equals(current_key)) { - //System.out.println("public keys are equal"); + return Left( + new Error.FormatError.Signature.InvalidSignature( + "signature error: Verification equation was not satisfied")); + } + } else { + // System.out.println("checking final signature"); - return Right(null); - } else { - //System.out.println("public keys are not equal"); + byte[] finalSignature = this.proof.getSignature().get(); - return Left(new Error.FormatError.Signature.InvalidSignature("signature error: Verification equation was not satisfied")); - } - } else { - //System.out.println("checking final signature"); + SignedBlock b; + if (this.blocks.isEmpty()) { + b = this.authority; + } else { + b = this.blocks.get(this.blocks.size() - 1); + } - byte[] finalSignature = this.proof.signature.get(); + byte[] block = b.getBlock(); + org.biscuitsec.biscuit.crypto.PublicKey nextKey = b.getKey(); + byte[] signature = b.getSignature(); - SignedBlock b; - if (this.blocks.isEmpty()) { - b = this.authority; - } else { - b = this.blocks.get(this.blocks.size() - 1); - } + byte[] payload = BlockSignatureBuffer.getBufferSealedSignature(nextKey, block, signature); - byte[] block = b.block; - org.biscuitsec.biscuit.crypto.PublicKey next_key = b.key; - byte[] signature = b.signature; + if (KeyPair.verify(currentKey, payload, finalSignature)) { + return Right(null); + } else { + return Left(new Error.FormatError.Signature.SealedSignature()); + } + } + } - byte[] payload = BlockSignatureBuffer.getBufferSealedSignature(next_key, block, signature); + static Either verifyBlockSignature( + SignedBlock signedBlock, org.biscuitsec.biscuit.crypto.PublicKey publicKey) + throws NoSuchAlgorithmException, InvalidKeyException, SignatureException { - if (KeyPair.verify(current_key, payload, finalSignature)) { - return Right(null); - } else { - return Left(new Error.FormatError.Signature.SealedSignature()); - } + org.biscuitsec.biscuit.crypto.PublicKey nextKey = signedBlock.getKey(); + byte[] signature = signedBlock.getSignature(); - } + var signatureLengthError = + PublicKey.validateSignatureLength(publicKey.getAlgorithm(), signature.length); + if (signatureLengthError.isPresent()) { + return Left(signatureLengthError.get()); } - static Either verifyBlockSignature(SignedBlock signedBlock, org.biscuitsec.biscuit.crypto.PublicKey publicKey) - throws NoSuchAlgorithmException, InvalidKeyException, SignatureException { - - byte[] block = signedBlock.block; - org.biscuitsec.biscuit.crypto.PublicKey next_key = signedBlock.key; - byte[] signature = signedBlock.signature; + ByteBuffer algoBuf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN); + algoBuf.putInt(Integer.valueOf(nextKey.getAlgorithm().getNumber())); + algoBuf.flip(); - var signatureLengthError = PublicKey.validateSignatureLength(publicKey.algorithm, signature.length); - if (signatureLengthError.isPresent()) { - return Left(signatureLengthError.get()); - } + byte[] block = signedBlock.getBlock(); + Signature sgr = KeyPair.generateSignature(publicKey.getAlgorithm()); + sgr.initVerify(publicKey.getKey()); + sgr.update(block); + if (signedBlock.getExternalSignature().isDefined()) { + sgr.update(signedBlock.getExternalSignature().get().getSignature()); + } + sgr.update(algoBuf); + sgr.update(nextKey.toBytes()); + byte[] payload = + BlockSignatureBuffer.getBufferSignature( + nextKey, block, signedBlock.getExternalSignature().toJavaOptional()); + if (!KeyPair.verify(publicKey, payload, signature)) { + return Left( + new Error.FormatError.Signature.InvalidSignature( + "signature error: Verification equation was not satisfied")); + } - ByteBuffer algo_buf = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN); - algo_buf.putInt(Integer.valueOf(next_key.algorithm.getNumber())); - algo_buf.flip(); + if (signedBlock.getExternalSignature().isDefined()) { + byte[] externalPayload = BlockSignatureBuffer.getBufferSignature(publicKey, block); + ExternalSignature externalSignature = signedBlock.getExternalSignature().get(); - Signature sgr = KeyPair.generateSignature(publicKey.algorithm); - sgr.initVerify(publicKey.key); - sgr.update(block); - if(signedBlock.externalSignature.isDefined()) { - sgr.update(signedBlock.externalSignature.get().signature); - } - sgr.update(algo_buf); - sgr.update(next_key.toBytes()); - byte[] payload = BlockSignatureBuffer.getBufferSignature(next_key, block, signedBlock.externalSignature.toJavaOptional()); - if (!KeyPair.verify(publicKey, payload, signature)) { - return Left(new Error.FormatError.Signature.InvalidSignature("signature error: Verification equation was not satisfied")); - } + if (!KeyPair.verify( + externalSignature.getKey(), externalPayload, externalSignature.getSignature())) { + return Left( + new Error.FormatError.Signature.InvalidSignature( + "external signature error: Verification equation was not satisfied")); + } + } - if (signedBlock.externalSignature.isDefined()) { - byte[] externalPayload = BlockSignatureBuffer.getBufferSignature(publicKey, block); - ExternalSignature externalSignature = signedBlock.externalSignature.get(); + return Right(nextKey); + } - if (!KeyPair.verify(externalSignature.key, externalPayload, externalSignature.signature)) { - return Left(new Error.FormatError.Signature.InvalidSignature("external signature error: Verification equation was not satisfied")); - } - } + public Tuple2> extractBlocks(SymbolTable symbolTable) throws Error { + ArrayList> blockExternalKeys = + new ArrayList<>(); + Either authRes = + Block.fromBytes(this.authority.getBlock(), Option.none()); + if (authRes.isLeft()) { + throw authRes.getLeft(); + } + Block authority = authRes.get(); + for (org.biscuitsec.biscuit.crypto.PublicKey pk : authority.getPublicKeys()) { + symbolTable.insert(pk); + } + blockExternalKeys.add(Option.none()); - return Right(next_key); + for (String s : authority.getSymbolTable().symbols()) { + symbolTable.add(s); } - public Tuple2> extractBlocks(SymbolTable symbols) throws Error { - ArrayList> blockExternalKeys = new ArrayList<>(); - Either authRes = Block.from_bytes(this.authority.block, Option.none()); - if (authRes.isLeft()) { - throw authRes.getLeft(); - } - Block authority = authRes.get(); - for(org.biscuitsec.biscuit.crypto.PublicKey pk: authority.publicKeys()) { - symbols.insert(pk); - } + ArrayList blocks = new ArrayList<>(); + for (SignedBlock bdata : this.blocks) { + Option externalKey = Option.none(); + if (bdata.getExternalSignature().isDefined()) { + externalKey = Option.some(bdata.getExternalSignature().get().getKey()); + } + Either blockRes = Block.fromBytes(bdata.getBlock(), externalKey); + if (blockRes.isLeft()) { + throw blockRes.getLeft(); + } + Block block = blockRes.get(); + + // blocks with external signatures keep their own symbol table + if (bdata.getExternalSignature().isDefined()) { + // symbolTable.insert(bdata.externalSignature.get().key); + blockExternalKeys.add(Option.some(bdata.getExternalSignature().get().getKey())); + } else { blockExternalKeys.add(Option.none()); - - for (String s : authority.symbols().symbols) { - symbols.add(s); + for (String s : block.getSymbolTable().symbols()) { + symbolTable.add(s); } - - ArrayList blocks = new ArrayList<>(); - for (SignedBlock bdata : this.blocks) { - Option externalKey = Option.none(); - if(bdata.externalSignature.isDefined()) { - externalKey = Option.some(bdata.externalSignature.get().key); - } - Either blockRes = Block.from_bytes(bdata.block, externalKey); - if (blockRes.isLeft()) { - throw blockRes.getLeft(); - } - Block block = blockRes.get(); - - // blocks with external signatures keep their own symbol table - if(bdata.externalSignature.isDefined()) { - //symbols.insert(bdata.externalSignature.get().key); - blockExternalKeys.add(Option.some(bdata.externalSignature.get().key)); - } else { - blockExternalKeys.add(Option.none()); - for (String s : block.symbols().symbols) { - symbols.add(s); - } - for(org.biscuitsec.biscuit.crypto.PublicKey pk: block.publicKeys()) { - symbols.insert(pk); - } - } - - blocks.add(block); + for (org.biscuitsec.biscuit.crypto.PublicKey pk : block.getPublicKeys()) { + symbolTable.insert(pk); } + } - return new Tuple2<>(authority, blocks); + blocks.add(block); } - public Either seal() throws InvalidKeyException, NoSuchAlgorithmException, SignatureException { - if (this.proof.secretKey.isEmpty()) { - return Left(new Error.Sealed()); - } - - SignedBlock block; - if (this.blocks.isEmpty()) { - block = this.authority; - } else { - block = this.blocks.get(this.blocks.size() - 1); - } - - KeyPair secretKey = this.proof.secretKey.get(); - byte[] payload = BlockSignatureBuffer.getBufferSealedSignature(block.key, block.block, block.signature); - byte[] signature = secretKey.sign(payload); + return new Tuple2<>(authority, blocks); + } - this.proof.secretKey = Option.none(); - this.proof.signature = Option.some(signature); + public Either seal() + throws InvalidKeyException, NoSuchAlgorithmException, SignatureException { + if (this.proof.isSealed()) { + return Left(new Error.Sealed()); + } - return Right(null); + SignedBlock block; + if (this.blocks.isEmpty()) { + block = this.authority; + } else { + block = this.blocks.get(this.blocks.size() - 1); } - public List revocation_identifiers() { - ArrayList l = new ArrayList<>(); - l.add(this.authority.signature); + KeyPair secretKey = ((Proof.NextSecret) this.proof).secretKey(); + byte[] payload = + BlockSignatureBuffer.getBufferSealedSignature( + block.getKey(), block.getBlock(), block.getSignature()); + byte[] signature = secretKey.sign(payload); - for (SignedBlock block : this.blocks) { - l.add(block.signature); - } - return l; - } + this.proof = new Proof.FinalSignature(signature); - SerializedBiscuit(SignedBlock authority, List blocks, Proof proof) { - this.authority = authority; - this.blocks = blocks; - this.proof = proof; - this.root_key_id = Option.none(); - } + return Right(null); + } + + public List revocationIdentifiers() { + ArrayList l = new ArrayList<>(); + l.add(this.authority.getSignature()); - SerializedBiscuit(SignedBlock authority, List blocks, Proof proof, Option root_key_id) { - this.authority = authority; - this.blocks = blocks; - this.proof = proof; - this.root_key_id = root_key_id; + for (SignedBlock block : this.blocks) { + l.add(block.getSignature()); } + return l; + } + + SerializedBiscuit(SignedBlock authority, List blocks, Proof proof) { + this.authority = authority; + this.blocks = blocks; + this.proof = proof; + this.rootKeyId = Option.none(); + } + + SerializedBiscuit( + SignedBlock authority, List blocks, Proof proof, Option rootKeyId) { + this.authority = authority; + this.blocks = blocks; + this.proof = proof; + this.rootKeyId = rootKeyId; + } + + public SignedBlock getAuthority() { + return authority; + } + + public List getBlocks() { + return blocks; + } + + public Proof getProof() { + return proof; + } + + public Option getRootKeyId() { + return rootKeyId; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/format/SignedBlock.java b/src/main/java/org/biscuitsec/biscuit/token/format/SignedBlock.java index d7e6d4b1..77d61a7d 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/format/SignedBlock.java +++ b/src/main/java/org/biscuitsec/biscuit/token/format/SignedBlock.java @@ -1,18 +1,35 @@ package org.biscuitsec.biscuit.token.format; -import org.biscuitsec.biscuit.crypto.PublicKey; import io.vavr.control.Option; +import org.biscuitsec.biscuit.crypto.PublicKey; public class SignedBlock { - public byte[] block; - public PublicKey key; - public byte[] signature; - public Option externalSignature; + private byte[] block; + private PublicKey key; + private byte[] signature; + private Option externalSignature; + + public SignedBlock( + byte[] block, PublicKey key, byte[] signature, Option externalSignature) { + this.block = block; + this.key = key; + this.signature = signature; + this.externalSignature = externalSignature; + } + + public byte[] getBlock() { + return block; + } + + public PublicKey getKey() { + return key; + } + + public byte[] getSignature() { + return signature; + } - public SignedBlock(byte[] block, PublicKey key, byte[] signature, Option externalSignature) { - this.block = block; - this.key = key; - this.signature = signature; - this.externalSignature = externalSignature; - } + public Option getExternalSignature() { + return externalSignature; + } } diff --git a/src/main/java/org/biscuitsec/biscuit/token/format/package-info.java b/src/main/java/org/biscuitsec/biscuit/token/format/package-info.java index 18f19f34..ff94d20f 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/format/package-info.java +++ b/src/main/java/org/biscuitsec/biscuit/token/format/package-info.java @@ -1,4 +1,2 @@ -/** - * Serializing code - */ -package org.biscuitsec.biscuit.token.format; \ No newline at end of file +/** Serializing code */ +package org.biscuitsec.biscuit.token.format; diff --git a/src/main/java/org/biscuitsec/biscuit/token/package-info.java b/src/main/java/org/biscuitsec/biscuit/token/package-info.java index 5f91452a..5cbcbbe4 100644 --- a/src/main/java/org/biscuitsec/biscuit/token/package-info.java +++ b/src/main/java/org/biscuitsec/biscuit/token/package-info.java @@ -1,4 +1,2 @@ -/** - * Classes related to creating and verifying Biscuit tokens - */ -package org.biscuitsec.biscuit.token; \ No newline at end of file +/** Classes related to creating and verifying Biscuit tokens */ +package org.biscuitsec.biscuit.token; diff --git a/src/test/java/org/biscuitsec/biscuit/builder/BuilderTest.java b/src/test/java/org/biscuitsec/biscuit/builder/BuilderTest.java index 55708518..8c618694 100644 --- a/src/test/java/org/biscuitsec/biscuit/builder/BuilderTest.java +++ b/src/test/java/org/biscuitsec/biscuit/builder/BuilderTest.java @@ -1,6 +1,18 @@ package org.biscuitsec.biscuit.builder; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + import biscuit.format.schema.Schema; +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; +import java.time.Instant; +import java.util.Arrays; +import java.util.Date; +import java.util.HashSet; +import java.util.Set; import org.biscuitsec.biscuit.crypto.KeyPair; import org.biscuitsec.biscuit.datalog.SymbolTable; import org.biscuitsec.biscuit.error.Error; @@ -11,98 +23,119 @@ import org.biscuitsec.biscuit.token.builder.Utils; import org.junit.jupiter.api.Test; -import java.nio.charset.StandardCharsets; -import java.security.SecureRandom; -import java.time.Instant; -import java.util.Arrays; -import java.util.Date; -import java.util.HashSet; -import java.util.Set; - -import static org.junit.jupiter.api.Assertions.*; - public class BuilderTest { - @Test - public void testBuild() throws Error.Language, Error.SymbolTableOverlap, Error.FormatError { - SecureRandom rng = new SecureRandom(); - KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - SymbolTable symbols = Biscuit.default_symbol_table(); + @Test + public void testBuild() throws Error.Language, Error.SymbolTableOverlap, Error.FormatError { + SecureRandom rng = new SecureRandom(); + KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + SymbolTable symbolTable = Biscuit.defaultSymbolTable(); - Block authority_builder = new Block(); - authority_builder.add_fact(Utils.fact("revocation_id", Arrays.asList(Utils.date(Date.from(Instant.now()))))); - authority_builder.add_fact(Utils.fact("right", Arrays.asList(Utils.s("admin")))); - authority_builder.add_rule(Utils.constrained_rule("right", - Arrays.asList(Utils.s("namespace"), Utils.var("tenant"), Utils.var("namespace"), Utils.var("operation")), - Arrays.asList(Utils.pred("ns_operation", Arrays.asList(Utils.s("namespace"), Utils.var("tenant"), Utils.var("namespace"), Utils.var("operation")))), - Arrays.asList( - new Expression.Binary( - Expression.Op.Contains, - new Expression.Value(Utils.var("operation")), - new Expression.Value(new Term.Set(new HashSet<>(Arrays.asList( - Utils.s("create_topic"), - Utils.s("get_topic"), - Utils.s("get_topics") - ))))) - ) - )); - authority_builder.add_rule(Utils.constrained_rule("right", - Arrays.asList(Utils.s("topic"), Utils.var("tenant"), Utils.var("namespace"), Utils.var("topic"), Utils.var("operation")), - Arrays.asList(Utils.pred("topic_operation", Arrays.asList(Utils.s("topic"), Utils.var("tenant"), Utils.var("namespace"), Utils.var("topic"), Utils.var("operation")))), - Arrays.asList( - new Expression.Binary( - Expression.Op.Contains, - new Expression.Value(Utils.var("operation")), - new Expression.Value(new Term.Set(new HashSet<>(Arrays.asList( - Utils.s("lookup") - ))))) - ) - )); + Block authorityBuilder = new Block(); + authorityBuilder.addFact( + Utils.fact("revocation_id", Arrays.asList(Utils.date(Date.from(Instant.now()))))); + authorityBuilder.addFact(Utils.fact("right", Arrays.asList(Utils.str("admin")))); + authorityBuilder.addRule( + Utils.constrainedRule( + "right", + Arrays.asList( + Utils.str("namespace"), + Utils.var("tenant"), + Utils.var("namespace"), + Utils.var("operation")), + Arrays.asList( + Utils.pred( + "ns_operation", + Arrays.asList( + Utils.str("namespace"), + Utils.var("tenant"), + Utils.var("namespace"), + Utils.var("operation")))), + Arrays.asList( + new Expression.Binary( + Expression.Op.Contains, + new Expression.Value(Utils.var("operation")), + new Expression.Value( + new Term.Set( + new HashSet<>( + Arrays.asList( + Utils.str("create_topic"), + Utils.str("get_topic"), + Utils.str("get_topics"))))))))); + authorityBuilder.addRule( + Utils.constrainedRule( + "right", + Arrays.asList( + Utils.str("topic"), + Utils.var("tenant"), + Utils.var("namespace"), + Utils.var("topic"), + Utils.var("operation")), + Arrays.asList( + Utils.pred( + "topic_operation", + Arrays.asList( + Utils.str("topic"), + Utils.var("tenant"), + Utils.var("namespace"), + Utils.var("topic"), + Utils.var("operation")))), + Arrays.asList( + new Expression.Binary( + Expression.Op.Contains, + new Expression.Value(Utils.var("operation")), + new Expression.Value( + new Term.Set(new HashSet<>(Arrays.asList(Utils.str("lookup"))))))))); - org.biscuitsec.biscuit.token.Block authority = authority_builder.build(symbols); - Biscuit rootBiscuit = Biscuit.make(rng, root, authority); + org.biscuitsec.biscuit.token.Block authority = authorityBuilder.build(symbolTable); + Biscuit rootBiscuit = Biscuit.make(rng, root, authority); - System.out.println(rootBiscuit.print()); + System.out.println(rootBiscuit.print()); - assertNotNull(rootBiscuit); - } + assertNotNull(rootBiscuit); + } - @Test - public void testStringValueOfAStringTerm() { - assertEquals( "\"hello\"", new Term.Str("hello").toString() ); - } + @Test + public void testStringValueOfStringTerm() { + assertEquals("\"hello\"", new Term.Str("hello").toString()); + } - @Test - public void testStringValueOfAnIntegerTerm() { - assertEquals( "123", new Term.Integer(123).toString() ); - } + @Test + public void testStringValueOfIntegerTerm() { + assertEquals("123", new Term.Integer(123).toString()); + } - @Test - public void testStringValueOfAVariableTerm() { - assertEquals( "$hello", new Term.Variable("hello").toString() ); - } + @Test + public void testStringValueOfVariableTerm() { + assertEquals("$hello", new Term.Variable("hello").toString()); + } - @Test - public void testStringValueOfASetTerm() { - String actual = new Term.Set(Set.of(new Term.Str("a"), new Term.Str("b"), new Term.Integer((3)))).toString(); - assertTrue(actual.startsWith("["), "starts with ["); - assertTrue(actual.endsWith("]"), "ends with ]"); - assertTrue(actual.contains("\"a\""), "contains a"); - assertTrue(actual.contains("\"b\""), "contains b"); - assertTrue(actual.contains("3"), "contains 3"); - } + @Test + public void testStringValueOfSetTerm() { + String actual = + new Term.Set(Set.of(new Term.Str("a"), new Term.Str("b"), new Term.Integer((3)))) + .toString(); + assertTrue(actual.startsWith("["), "starts with ["); + assertTrue(actual.endsWith("]"), "ends with ]"); + assertTrue(actual.contains("\"a\""), "contains a"); + assertTrue(actual.contains("\"b\""), "contains b"); + assertTrue(actual.contains("3"), "contains 3"); + } - @Test - public void testStringValueOfAByteArrayTermIsJustTheArrayReferenceNotTheContents() { - String string = new Term.Bytes("Hello".getBytes(StandardCharsets.UTF_8)).toString(); - assertTrue(string.startsWith("hex:"), "starts with hex prefix"); - } + @Test + public void testStringValueOfByteArrayTermIsJustTheArrayReferenceNotTheContents() { + String string = new Term.Bytes("Hello".getBytes(StandardCharsets.UTF_8)).toString(); + assertTrue(string.startsWith("hex:"), "starts with hex prefix"); + } - @Test - public void testArrayValueIsCopy() { - byte[] someBytes = "Hello".getBytes(StandardCharsets.UTF_8); - Term.Bytes term = new Term.Bytes(someBytes); - assertTrue(Arrays.equals(someBytes, term.getValue()), "same content"); - assertNotEquals(System.identityHashCode(someBytes), System.identityHashCode(term.getValue()), "different objects"); - } + @Test + public void testArrayValueIsCopy() { + byte[] someBytes = "Hello".getBytes(StandardCharsets.UTF_8); + Term.Bytes term = new Term.Bytes(someBytes); + assertTrue(Arrays.equals(someBytes, term.getValue()), "same content"); + assertNotEquals( + System.identityHashCode(someBytes), + System.identityHashCode(term.getValue()), + "different objects"); + } } diff --git a/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java b/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java index 22184356..99d1fd38 100644 --- a/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java +++ b/src/test/java/org/biscuitsec/biscuit/builder/parser/ParserTest.java @@ -1,513 +1,530 @@ package org.biscuitsec.biscuit.builder.parser; +import static org.biscuitsec.biscuit.datalog.Check.Kind.ONE; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + import biscuit.format.schema.Schema; +import io.vavr.Tuple2; +import io.vavr.control.Either; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; import org.biscuitsec.biscuit.crypto.PublicKey; import org.biscuitsec.biscuit.datalog.SymbolTable; import org.biscuitsec.biscuit.datalog.TemporarySymbolTable; import org.biscuitsec.biscuit.datalog.expressions.Op; -import org.biscuitsec.biscuit.token.Biscuit; +import org.biscuitsec.biscuit.token.builder.Block; +import org.biscuitsec.biscuit.token.builder.Check; +import org.biscuitsec.biscuit.token.builder.Expression; +import org.biscuitsec.biscuit.token.builder.Fact; +import org.biscuitsec.biscuit.token.builder.Predicate; +import org.biscuitsec.biscuit.token.builder.Rule; +import org.biscuitsec.biscuit.token.builder.Scope; +import org.biscuitsec.biscuit.token.builder.Term; +import org.biscuitsec.biscuit.token.builder.Utils; import org.biscuitsec.biscuit.token.builder.parser.Error; import org.biscuitsec.biscuit.token.builder.parser.Parser; -import io.vavr.Tuple2; -import io.vavr.control.Either; -import org.biscuitsec.biscuit.token.builder.*; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import static org.biscuitsec.biscuit.datalog.Check.Kind.One; -import static org.junit.jupiter.api.Assertions.*; - -import java.util.*; - class ParserTest { - @Test - void testName() { - Either> res = Parser.name("operation(read)"); - assertEquals(Either.right(new Tuple2<>("(read)", "operation")), res); - } - - @Test - void testString() { - Either> res = Parser.string("\"file1 a hello - 123_\""); - assertEquals(Either.right(new Tuple2<>("", (Term.Str) Utils.string("file1 a hello - 123_"))), res); - } - - @Test - void testInteger() { - Either> res = Parser.integer("123"); - assertEquals(Either.right(new Tuple2<>("", (Term.Integer) Utils.integer(123))), res); - - Either> res2 = Parser.integer("-42"); - assertEquals(Either.right(new Tuple2<>("", (Term.Integer) Utils.integer(-42))), res2); - } - - @Test - void testDate() { - Either> res = Parser.date("2019-12-02T13:49:53Z,"); - assertEquals(Either.right(new Tuple2<>(",", new Term.Date(1575294593))), res); - } - - @Test - void testVariable() { - Either> res = Parser.variable("$name"); - assertEquals(Either.right(new Tuple2<>("", (Term.Variable) Utils.var("name"))), res); - } - - @Test - void testFact() throws org.biscuitsec.biscuit.error.Error.Language { - Either> res = Parser.fact("right( \"file1\", \"read\" )"); - assertEquals(Either.right(new Tuple2<>("", - Utils.fact("right", Arrays.asList(Utils.string("file1"), Utils.s("read"))))), - res); - - Either> res2 = Parser.fact("right( $var, \"read\" )"); - assertEquals(Either.left(new Error("$var, \"read\" )", "closing parens not found")), - res2); - - Either> res3 = Parser.fact("date(2019-12-02T13:49:53Z)"); - assertEquals(Either.right(new Tuple2<>("", - Utils.fact("date", List.of(new Term.Date(1575294593))))), - res3); - - Either> res4 = Parser.fact("n1:right( \"file1\", \"read\" )"); - assertEquals(Either.right(new Tuple2<>("", - Utils.fact("n1:right", Arrays.asList(Utils.string("file1"), Utils.s("read"))))), - res4); - } - - @Test - void testRule() { - Either> res = - Parser.rule("right($resource, \"read\") <- resource($resource), operation(\"read\")"); - assertEquals(Either.right(new Tuple2<>("", - Utils.rule("right", - Arrays.asList(Utils.var("resource"), Utils.s("read")), - Arrays.asList( - Utils.pred("resource", List.of(Utils.var("resource"))), - Utils.pred("operation", List.of(Utils.s("read")))) - ))), - res); - } - - @Test - void testRuleWithExpression() { - Either> res = - Parser.rule("valid_date(\"file1\") <- time($0 ), resource( \"file1\"), $0 <= 2019-12-04T09:46:41Z"); - assertEquals(Either.right(new Tuple2<>("", - Utils.constrained_rule("valid_date", - List.of(Utils.string("file1")), - Arrays.asList( - Utils.pred("time", List.of(Utils.var("0"))), - Utils.pred("resource", List.of(Utils.string("file1"))) - ), - List.of( - new Expression.Binary( - Expression.Op.LessOrEqual, - new Expression.Value(Utils.var("0")), - new Expression.Value(new Term.Date(1575452801))) - ) - ))), - res); - } - - @Test - void testRuleWithExpressionOrdering() { - Either> res = - Parser.rule("valid_date(\"file1\") <- time($0 ), $0 <= 2019-12-04T09:46:41Z, resource(\"file1\")"); - assertEquals(Either.right(new Tuple2<>("", - Utils.constrained_rule("valid_date", - List.of(Utils.string("file1")), - Arrays.asList( - Utils.pred("time", List.of(Utils.var("0"))), - Utils.pred("resource", List.of(Utils.string("file1"))) - ), - List.of( - new Expression.Binary( - Expression.Op.LessOrEqual, - new Expression.Value(Utils.var("0")), - new Expression.Value(new Term.Date(1575452801))) - ) - ))), - res); - } - - @Test - void expressionIntersectionAndContainsTest() { - Either> res = - Parser.expression("[1, 2, 3].intersection([1, 2]).contains(1)"); - - assertEquals(Either.right(new Tuple2<>("", - new Expression.Binary( - Expression.Op.Contains, + @Test + void testName() { + Either> res = Parser.name("operation(read)"); + assertEquals(Either.right(new Tuple2<>("(read)", "operation")), res); + } + + @Test + void testString() { + Either> res = Parser.string("\"file1 a hello - 123_\""); + assertEquals( + Either.right(new Tuple2<>("", (Term.Str) Utils.string("file1 a hello - 123_"))), res); + } + + @Test + void testInteger() { + Either> res = Parser.integer("123"); + assertEquals(Either.right(new Tuple2<>("", (Term.Integer) Utils.integer(123))), res); + + Either> res2 = Parser.integer("-42"); + assertEquals(Either.right(new Tuple2<>("", (Term.Integer) Utils.integer(-42))), res2); + } + + @Test + void testDate() { + Either> res = Parser.date("2019-12-02T13:49:53Z,"); + assertEquals(Either.right(new Tuple2<>(",", new Term.Date(1575294593))), res); + } + + @Test + void testVariable() { + Either> res = Parser.variable("$name"); + assertEquals(Either.right(new Tuple2<>("", (Term.Variable) Utils.var("name"))), res); + } + + @Test + void testFact() throws org.biscuitsec.biscuit.error.Error.Language { + Either> res = Parser.fact("right( \"file1\", \"read\" )"); + assertEquals( + Either.right( + new Tuple2<>( + "", Utils.fact("right", Arrays.asList(Utils.string("file1"), Utils.str("read"))))), + res); + + Either> res2 = Parser.fact("right( $var, \"read\" )"); + assertEquals(Either.left(new Error("$var, \"read\" )", "closing parens not found")), res2); + + Either> res3 = Parser.fact("date(2019-12-02T13:49:53Z)"); + assertEquals( + Either.right(new Tuple2<>("", Utils.fact("date", List.of(new Term.Date(1575294593))))), + res3); + + Either> res4 = Parser.fact("n1:right( \"file1\", \"read\" )"); + assertEquals( + Either.right( + new Tuple2<>( + "", + Utils.fact("n1:right", Arrays.asList(Utils.string("file1"), Utils.str("read"))))), + res4); + } + + @Test + void testRule() { + Either> res = + Parser.rule("right($resource, \"read\") <- resource($resource), operation(\"read\")"); + assertEquals( + Either.right( + new Tuple2<>( + "", + Utils.rule( + "right", + Arrays.asList(Utils.var("resource"), Utils.str("read")), + Arrays.asList( + Utils.pred("resource", List.of(Utils.var("resource"))), + Utils.pred("operation", List.of(Utils.str("read"))))))), + res); + } + + @Test + void testRuleWithExpression() { + Either> res = + Parser.rule( + "valid_date(\"file1\") <- time($0 ), resource( \"file1\"), $0 <= 2019-12-04T09:46:41Z"); + assertEquals( + Either.right( + new Tuple2<>( + "", + Utils.constrainedRule( + "valid_date", + List.of(Utils.string("file1")), + Arrays.asList( + Utils.pred("time", List.of(Utils.var("0"))), + Utils.pred("resource", List.of(Utils.string("file1")))), + List.of( + new Expression.Binary( + Expression.Op.LessOrEqual, + new Expression.Value(Utils.var("0")), + new Expression.Value(new Term.Date(1575452801))))))), + res); + } + + @Test + void testRuleWithExpressionOrdering() { + Either> res = + Parser.rule( + "valid_date(\"file1\") <- time($0 ), $0 <= 2019-12-04T09:46:41Z, resource(\"file1\")"); + assertEquals( + Either.right( + new Tuple2<>( + "", + Utils.constrainedRule( + "valid_date", + List.of(Utils.string("file1")), + Arrays.asList( + Utils.pred("time", List.of(Utils.var("0"))), + Utils.pred("resource", List.of(Utils.string("file1")))), + List.of( new Expression.Binary( - Expression.Op.Intersection, - new Expression.Value(Utils.set(new HashSet<>(Arrays.asList(Utils.integer(1), Utils.integer(2), Utils.integer(3))))), - new Expression.Value(Utils.set(new HashSet<>(Arrays.asList(Utils.integer(1), Utils.integer(2))))) - ), - new Expression.Value(Utils.integer(1)) - ))), res); - } - - @Test - void expressionIntersectionAndContainsAndLengthEqualsTest() { - Either> res = - Parser.expression("[1, 2, 3].intersection([1, 2]).length() == 2"); - - assertEquals(Either.right(new Tuple2<>("", + Expression.Op.LessOrEqual, + new Expression.Value(Utils.var("0")), + new Expression.Value(new Term.Date(1575452801))))))), + res); + } + + @Test + void expressionIntersectionAndContainsTest() { + Either> res = + Parser.expression("[1, 2, 3].intersection([1, 2]).contains(1)"); + + assertEquals( + Either.right( + new Tuple2<>( + "", + new Expression.Binary( + Expression.Op.Contains, + new Expression.Binary( + Expression.Op.Intersection, + new Expression.Value( + Utils.set( + new HashSet<>( + Arrays.asList( + Utils.integer(1), Utils.integer(2), Utils.integer(3))))), + new Expression.Value( + Utils.set( + new HashSet<>(Arrays.asList(Utils.integer(1), Utils.integer(2)))))), + new Expression.Value(Utils.integer(1))))), + res); + } + + @Test + void expressionIntersectionAndContainsAndLengthEqualsTest() { + Either> res = + Parser.expression("[1, 2, 3].intersection([1, 2]).length() == 2"); + + assertEquals( + Either.right( + new Tuple2<>( + "", new Expression.Binary( - Expression.Op.Equal, - new Expression.Unary( - Expression.Op.Length, + Expression.Op.Equal, + new Expression.Unary( + Expression.Op.Length, + new Expression.Binary( + Expression.Op.Intersection, + new Expression.Value( + Utils.set( + new HashSet<>( + Arrays.asList( + Utils.integer(1), + Utils.integer(2), + Utils.integer(3))))), + new Expression.Value( + Utils.set( + new HashSet<>( + Arrays.asList(Utils.integer(1), Utils.integer(2))))))), + new Expression.Value(Utils.integer(2))))), + res); + } + + @Test + void testNegatePrecedence() { + Either> res = Parser.check("check if !false && true"); + assertEquals( + Either.right( + new Tuple2<>( + "", + Utils.check( + Utils.constrainedRule( + "query", + new ArrayList<>(), + new ArrayList<>(), + List.of( new Expression.Binary( - Expression.Op.Intersection, - new Expression.Value(Utils.set(new HashSet<>(Arrays.asList(Utils.integer(1), Utils.integer(2), Utils.integer(3))))), - new Expression.Value(Utils.set(new HashSet<>(Arrays.asList(Utils.integer(1), Utils.integer(2))))) - ) - ), - new Expression.Value(Utils.integer(2)) - ))), res); - } - - @Test - void testNegatePrecedence() { - Either> res = - Parser.check("check if !false && true"); - assertEquals(Either.right(new Tuple2<>("", - Utils.check( - Utils.constrained_rule("query", - new ArrayList<>(), - new ArrayList<>(), - List.of( - new Expression.Binary( - Expression.Op.And, - new Expression.Unary( - Expression.Op.Negate, - new Expression.Value(new Term.Bool(false)) - ), - new Expression.Value(new Term.Bool(true)) - ) - ) - )))), - res); - } - - @Test - void ruleWithFreeExpressionVariables() { - Either> res = - Parser.rule("right($0) <- resource($0), operation(\"read\"), $test"); - assertEquals( - Either.left( - new Error(" resource($0), operation(\"read\"), $test", - "rule head or expressions contains variables that are not used in predicates of the rule's body: [test]") - ), - res); - } - - @Test - void testRuleWithScope() { - Either> res = - Parser.rule("valid_date(\"file1\") <- resource(\"file1\") trusting ed25519/6e9e6d5a75cf0c0e87ec1256b4dfed0ca3ba452912d213fcc70f8516583db9db, authority "); - assertEquals(Either.right(new Tuple2<>("", - new Rule( - new Predicate( - "valid_date", - List.of(Utils.string("file1") - )), - Arrays.asList( - Utils.pred("resource", List.of(Utils.string("file1"))) - ), - new ArrayList<>(), - Arrays.asList( - Scope.publicKey(new PublicKey(Schema.PublicKey.Algorithm.Ed25519, "6e9e6d5a75cf0c0e87ec1256b4dfed0ca3ba452912d213fcc70f8516583db9db")), - Scope.authority() - ) - ))), - res); - } - - @Test - void testCheck() { - Either> res = - Parser.check("check if resource($0), operation(\"read\") or admin()"); - assertEquals(Either.right(new Tuple2<>("", new Check( - One, - Arrays.asList( - Utils.rule("query", + Expression.Op.And, + new Expression.Unary( + Expression.Op.Negate, + new Expression.Value(new Term.Bool(false))), + new Expression.Value(new Term.Bool(true)))))))), + res); + } + + @Test + void ruleWithFreeExpressionVariables() { + Either> res = + Parser.rule("right($0) <- resource($0), operation(\"read\"), $test"); + assertEquals( + Either.left( + new Error( + " resource($0), operation(\"read\"), $test", + "rule head or expressions contains variables that are not used in predicates of the" + + " rule's body: [test]")), + res); + } + + @Test + void testRuleWithScope() { + Either> res = + Parser.rule( + "valid_date(\"file1\") <- resource(\"file1\") trusting" + + " ed25519/6e9e6d5a75cf0c0e87ec1256b4dfed0ca3ba452912d213fcc70f8516583db9db," + + " authority "); + Rule refRule = + new Rule( + new Predicate("valid_date", List.of(Utils.string("file1"))), + List.of(Utils.pred("resource", List.of(Utils.string("file1")))), + new ArrayList<>(), + Arrays.asList( + Scope.publicKey( + new PublicKey( + Schema.PublicKey.Algorithm.Ed25519, + "6e9e6d5a75cf0c0e87ec1256b4dfed0ca3ba452912d213fcc70f8516583db9db")), + Scope.authority())); + assertEquals(Either.right(new Tuple2<>("", refRule)), res); + } + + @Test + void testCheck() { + Either> res = + Parser.check("check if resource($0), operation(\"read\") or admin()"); + assertEquals( + Either.right( + new Tuple2<>( + "", + new Check( + ONE, + Arrays.asList( + Utils.rule( + "query", new ArrayList<>(), Arrays.asList( - Utils.pred("resource", List.of(Utils.var("0"))), - Utils.pred("operation", List.of(Utils.s("read"))) - ) - ), - Utils.rule("query", + Utils.pred("resource", List.of(Utils.var("0"))), + Utils.pred("operation", List.of(Utils.str("read"))))), + Utils.rule( + "query", new ArrayList<>(), - List.of( - Utils.pred("admin", List.of()) - ) - ) - )))), - res); - } - - @Test - void testExpression() { - Either> res = - Parser.expression(" -1 "); - - assertEquals(new Tuple2("", - new Expression.Value(Utils.integer(-1))), - res.get()); - - Either> res2 = - Parser.expression(" $0 <= 2019-12-04T09:46:41+00:00"); - - assertEquals(new Tuple2("", - new Expression.Binary( - Expression.Op.LessOrEqual, - new Expression.Value(Utils.var("0")), - new Expression.Value(new Term.Date(1575452801)))), - res2.get()); - - Either> res3 = - Parser.expression(" 1 < $test + 2 "); - - assertEquals(Either.right(new Tuple2("", - new Expression.Binary( - Expression.Op.LessThan, - new Expression.Value(Utils.integer(1)), - new Expression.Binary( - Expression.Op.Add, - new Expression.Value(Utils.var("test")), - new Expression.Value(Utils.integer(2)) - ) - ) - )), - res3); - - SymbolTable s3 = new SymbolTable(); - long test = s3.insert("test"); - Assertions.assertEquals( - Arrays.asList( - new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Integer(1)), - new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Variable(test)), - new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Integer(2)), - new Op.Binary(Op.BinaryOp.Add), - new Op.Binary(Op.BinaryOp.LessThan) - ), - res3.get()._2.convert(s3).getOps() - ); - - Either> res4 = - Parser.expression(" 2 < $test && $var2.starts_with(\"test\") && true "); - - assertEquals(Either.right(new Tuple2("", - new Expression.Binary( - Expression.Op.And, - new Expression.Binary( - Expression.Op.And, - new Expression.Binary( - Expression.Op.LessThan, - new Expression.Value(Utils.integer(2)), - new Expression.Value(Utils.var("test")) - ), - new Expression.Binary( - Expression.Op.Prefix, - new Expression.Value(Utils.var("var2")), - new Expression.Value(Utils.string("test")) - ) - ), - new Expression.Value(new Term.Bool(true)) - ) - )), - res4); - - Either> res5 = - Parser.expression(" [ \"abc\", \"def\" ].contains($operation) "); - - HashSet s = new HashSet<>(); - s.add(Utils.s("abc")); - s.add(Utils.s("def")); - - assertEquals(Either.right(new Tuple2("", + List.of(Utils.pred("admin", List.of()))))))), + res); + } + + @Test + void testExpression() { + Either> res = Parser.expression(" -1 "); + + assertEquals( + new Tuple2("", new Expression.Value(Utils.integer(-1))), res.get()); + + Either> res2 = + Parser.expression(" $0 <= 2019-12-04T09:46:41+00:00"); + + assertEquals( + new Tuple2( + "", + new Expression.Binary( + Expression.Op.LessOrEqual, + new Expression.Value(Utils.var("0")), + new Expression.Value(new Term.Date(1575452801)))), + res2.get()); + + Either> res3 = Parser.expression(" 1 < $test + 2 "); + + assertEquals( + Either.right( + new Tuple2( + "", + new Expression.Binary( + Expression.Op.LessThan, + new Expression.Value(Utils.integer(1)), + new Expression.Binary( + Expression.Op.Add, + new Expression.Value(Utils.var("test")), + new Expression.Value(Utils.integer(2)))))), + res3); + + SymbolTable s3 = new SymbolTable(); + long test = s3.insert("test"); + assertEquals( + Arrays.asList( + new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Integer(1)), + new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Variable(test)), + new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Integer(2)), + new Op.Binary(Op.BinaryOp.Add), + new Op.Binary(Op.BinaryOp.LessThan)), + res3.get()._2.convert(s3).getOps()); + + Either> res4 = + Parser.expression(" 2 < $test && $var2.starts_with(\"test\") && true "); + + assertEquals( + Either.right( + new Tuple2( + "", + new Expression.Binary( + Expression.Op.And, + new Expression.Binary( + Expression.Op.And, new Expression.Binary( - Expression.Op.Contains, - new Expression.Value(Utils.set(s)), - new Expression.Value(Utils.var("operation")) - ) - )), - res5); - } - - @Test - void testParens() throws org.biscuitsec.biscuit.error.Error.Execution { - Either> res = - Parser.expression(" 1 + 2 * 3 "); - - assertEquals(Either.right(new Tuple2("", + Expression.Op.LessThan, + new Expression.Value(Utils.integer(2)), + new Expression.Value(Utils.var("test"))), new Expression.Binary( - Expression.Op.Add, - new Expression.Value(Utils.integer(1)), - new Expression.Binary( - Expression.Op.Mul, - new Expression.Value(Utils.integer(2)), - new Expression.Value(Utils.integer(3)) - ) - ) - )), - res); - - Expression e = res.get()._2; - SymbolTable s = new SymbolTable(); - - org.biscuitsec.biscuit.datalog.expressions.Expression ex = e.convert(s); - - assertEquals( - Arrays.asList( - new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Integer(1)), - new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Integer(2)), - new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Integer(3)), - new Op.Binary(Op.BinaryOp.Mul), - new Op.Binary(Op.BinaryOp.Add) - ), - ex.getOps() - ); - - Map variables = new HashMap<>(); - org.biscuitsec.biscuit.datalog.Term value = ex.evaluate(variables, new TemporarySymbolTable(s)); - assertEquals(new org.biscuitsec.biscuit.datalog.Term.Integer(7), value); - assertEquals("1 + 2 * 3", ex.print(s).get()); - - - Either> res2 = - Parser.expression(" (1 + 2) * 3 "); - - assertEquals(Either.right(new Tuple2("", + Expression.Op.Prefix, + new Expression.Value(Utils.var("var2")), + new Expression.Value(Utils.string("test")))), + new Expression.Value(new Term.Bool(true))))), + res4); + + Either> res5 = + Parser.expression(" [ \"abc\", \"def\" ].contains($operation) "); + + HashSet s = new HashSet<>(); + s.add(Utils.str("abc")); + s.add(Utils.str("def")); + + assertEquals( + Either.right( + new Tuple2( + "", + new Expression.Binary( + Expression.Op.Contains, + new Expression.Value(Utils.set(s)), + new Expression.Value(Utils.var("operation"))))), + res5); + } + + @Test + void testParens() throws org.biscuitsec.biscuit.error.Error.Execution { + Either> res = Parser.expression(" 1 + 2 * 3 "); + + assertEquals( + Either.right( + new Tuple2( + "", + new Expression.Binary( + Expression.Op.Add, + new Expression.Value(Utils.integer(1)), + new Expression.Binary( + Expression.Op.Mul, + new Expression.Value(Utils.integer(2)), + new Expression.Value(Utils.integer(3)))))), + res); + + Expression e = res.get()._2; + SymbolTable s = new SymbolTable(); + + org.biscuitsec.biscuit.datalog.expressions.Expression ex = e.convert(s); + + assertEquals( + Arrays.asList( + new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Integer(1)), + new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Integer(2)), + new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Integer(3)), + new Op.Binary(Op.BinaryOp.Mul), + new Op.Binary(Op.BinaryOp.Add)), + ex.getOps()); + + Map variables = new HashMap<>(); + org.biscuitsec.biscuit.datalog.Term value = ex.evaluate(variables, new TemporarySymbolTable(s)); + assertEquals(new org.biscuitsec.biscuit.datalog.Term.Integer(7), value); + assertEquals("1 + 2 * 3", ex.print(s).get()); + + Either> res2 = Parser.expression(" (1 + 2) * 3 "); + + assertEquals( + Either.right( + new Tuple2( + "", + new Expression.Binary( + Expression.Op.Mul, + new Expression.Unary( + Expression.Op.Parens, new Expression.Binary( - Expression.Op.Mul, - new Expression.Unary( - Expression.Op.Parens, - new Expression.Binary( - Expression.Op.Add, - new Expression.Value(Utils.integer(1)), - new Expression.Value(Utils.integer(2)) - )) - , - new Expression.Value(Utils.integer(3)) - ) - )), - res2); - - Expression e2 = res2.get()._2; - SymbolTable s2 = new SymbolTable(); - - org.biscuitsec.biscuit.datalog.expressions.Expression ex2 = e2.convert(s2); - - assertEquals( - Arrays.asList( - new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Integer(1)), - new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Integer(2)), - new Op.Binary(Op.BinaryOp.Add), - new Op.Unary(Op.UnaryOp.Parens), - new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Integer(3)), - new Op.Binary(Op.BinaryOp.Mul) - ), - ex2.getOps() - ); - - Map variables2 = new HashMap<>(); - org.biscuitsec.biscuit.datalog.Term value2 = ex2.evaluate(variables2, new TemporarySymbolTable(s2)); - assertEquals(new org.biscuitsec.biscuit.datalog.Term.Integer(9), value2); - assertEquals("(1 + 2) * 3", ex2.print(s2).get()); - } - - @Test - void testDatalogSucceeds() throws org.biscuitsec.biscuit.error.Error.Parser { - - String l1 = "fact1(1, 2)"; - String l2 = "fact2(\"2\")"; - String l3 = "rule1(2) <- fact2(\"2\")"; - String l4 = "check if rule1(2)"; - String toParse = String.join(";", Arrays.asList(l1, l2, l3, l4)); - - Either>, Block> output = Parser.datalog(1, toParse); - assertTrue(output.isRight()); - - Block validBlock = new Block(); - validBlock.add_fact(l1); - validBlock.add_fact(l2); - validBlock.add_rule(l3); - validBlock.add_check(l4); - - output.forEach(block -> - assertEquals(block, validBlock) - ); - } - - @Test - void testDatalogSucceedsArrays() throws org.biscuitsec.biscuit.error.Error.Parser { - SymbolTable symbols = Biscuit.default_symbol_table(); - - String l1 = "check if [2, 3].union([2])"; - String toParse = String.join(";", List.of(l1)); - - Either>, Block> output = Parser.datalog(1, toParse); - assertTrue(output.isRight()); - - Block validBlock = new Block(); - validBlock.add_check(l1); - - output.forEach(block -> - assertEquals(block, validBlock) - ); - } - - @Test - void testDatalogSucceedsArraysContains() throws org.biscuitsec.biscuit.error.Error.Parser { - SymbolTable symbols = Biscuit.default_symbol_table(); - - String l1 = "check if [2019-12-04T09:46:41Z, 2020-12-04T09:46:41Z].contains(2020-12-04T09:46:41Z)"; - String toParse = String.join(";", List.of(l1)); - - Either>, Block> output = Parser.datalog(1, toParse); - assertTrue(output.isRight()); - - Block validBlock = new Block(); - validBlock.add_check(l1); - - output.forEach(block -> - assertEquals(block, validBlock) - ); - } - - @Test - void testDatalogFailed() { - SymbolTable symbols = Biscuit.default_symbol_table(); - - String l1 = "fact(1)"; - String l2 = "check fact(1)"; // typo missing "if" - String toParse = String.join(";", Arrays.asList(l1, l2)); - - Either>, Block> output = Parser.datalog(1, toParse); - assertTrue(output.isLeft()); - } - - @Test - void testDatalogRemoveComment() throws org.biscuitsec.biscuit.error.Error.Parser { - SymbolTable symbols = Biscuit.default_symbol_table(); - - String l0 = "// test comment"; - String l1 = "fact1(1, 2);"; - String l2 = "fact2(\"2\");"; - String l3 = "rule1(2) <- fact2(\"2\");"; - String l4 = "// another comment"; - String l5 = "/* test multiline"; - String l6 = "comment */ check if rule1(2);"; - String l7 = " /* another multiline"; - String l8 = "comment */"; - String toParse = String.join("", Arrays.asList(l0, l1, l2, l3, l4, l5, l6, l7, l8)); - - Either>, Block> output = Parser.datalog(1, toParse); - assertTrue(output.isRight()); - } + Expression.Op.Add, + new Expression.Value(Utils.integer(1)), + new Expression.Value(Utils.integer(2)))), + new Expression.Value(Utils.integer(3))))), + res2); + + Expression e2 = res2.get()._2; + SymbolTable s2 = new SymbolTable(); + + org.biscuitsec.biscuit.datalog.expressions.Expression ex2 = e2.convert(s2); + + assertEquals( + Arrays.asList( + new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Integer(1)), + new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Integer(2)), + new Op.Binary(Op.BinaryOp.Add), + new Op.Unary(Op.UnaryOp.Parens), + new Op.Value(new org.biscuitsec.biscuit.datalog.Term.Integer(3)), + new Op.Binary(Op.BinaryOp.Mul)), + ex2.getOps()); + + Map variables2 = new HashMap<>(); + org.biscuitsec.biscuit.datalog.Term value2 = + ex2.evaluate(variables2, new TemporarySymbolTable(s2)); + assertEquals(new org.biscuitsec.biscuit.datalog.Term.Integer(9), value2); + assertEquals("(1 + 2) * 3", ex2.print(s2).get()); + } + + @Test + void testDatalogSucceeds() throws org.biscuitsec.biscuit.error.Error.Parser { + + String l1 = "fact1(1, 2)"; + String l2 = "fact2(\"2\")"; + String l3 = "rule1(2) <- fact2(\"2\")"; + String l4 = "check if rule1(2)"; + String toParse = String.join(";", Arrays.asList(l1, l2, l3, l4)); + + Either>, Block> output = Parser.datalog(1, toParse); + assertTrue(output.isRight()); + + Block validBlock = new Block(); + validBlock.addFact(l1); + validBlock.addFact(l2); + validBlock.addRule(l3); + validBlock.addCheck(l4); + + output.forEach(block -> assertEquals(block, validBlock)); + } + + @Test + void testDatalogSucceedsArrays() throws org.biscuitsec.biscuit.error.Error.Parser { + String l1 = "check if [2, 3].union([2])"; + String toParse = String.join(";", List.of(l1)); + + Either>, Block> output = Parser.datalog(1, toParse); + assertTrue(output.isRight()); + + Block validBlock = new Block(); + validBlock.addCheck(l1); + + output.forEach(block -> assertEquals(block, validBlock)); + } + + @Test + void testDatalogSucceedsArraysContains() throws org.biscuitsec.biscuit.error.Error.Parser { + String l1 = + "check if [2019-12-04T09:46:41Z, 2020-12-04T09:46:41Z].contains(2020-12-04T09:46:41Z)"; + String toParse = String.join(";", List.of(l1)); + + Either>, Block> output = Parser.datalog(1, toParse); + assertTrue(output.isRight()); + + Block validBlock = new Block(); + validBlock.addCheck(l1); + + output.forEach(block -> assertEquals(block, validBlock)); + } + + @Test + void testDatalogFailed() { + String l1 = "fact(1)"; + String l2 = "check fact(1)"; // typo missing "if" + String toParse = String.join(";", Arrays.asList(l1, l2)); + + Either>, Block> output = Parser.datalog(1, toParse); + assertTrue(output.isLeft()); + } + + @Test + void testDatalogRemoveComment() { + + String l0 = "// test comment"; + String l1 = "fact1(1, 2);"; + String l2 = "fact2(\"2\");"; + String l3 = "rule1(2) <- fact2(\"2\");"; + String l4 = "// another comment"; + String l5 = "/* test multiline"; + String l6 = "comment */ check if rule1(2);"; + String l7 = " /* another multiline"; + String l8 = "comment */"; + String toParse = String.join("", Arrays.asList(l0, l1, l2, l3, l4, l5, l6, l7, l8)); + + Either>, Block> output = Parser.datalog(1, toParse); + assertTrue(output.isRight()); + } } diff --git a/src/test/java/org/biscuitsec/biscuit/crypto/SignatureTest.java b/src/test/java/org/biscuitsec/biscuit/crypto/SignatureTest.java index b7f3cb67..69a8668a 100644 --- a/src/test/java/org/biscuitsec/biscuit/crypto/SignatureTest.java +++ b/src/test/java/org/biscuitsec/biscuit/crypto/SignatureTest.java @@ -1,131 +1,102 @@ package org.biscuitsec.biscuit.crypto; -import biscuit.format.schema.Schema; +import static io.vavr.API.Right; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import biscuit.format.schema.Schema; import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.security.SecureRandom; import java.security.SignatureException; - -import static biscuit.format.schema.Schema.PublicKey.Algorithm.*; -import static io.vavr.API.Left; -import static io.vavr.API.Right; - import org.biscuitsec.biscuit.error.Error; import org.biscuitsec.biscuit.token.Biscuit; import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; - /** * @serial exclude */ public class SignatureTest { - @Test - public void testSerialize() { - testSerialize(Ed25519, 32); - testSerialize(SECP256R1, 33); // compressed - 0x02 or 0x03 prefix byte, 32 bytes for X coordinate - } - - @Test - public void testThreeMessages() throws NoSuchAlgorithmException, SignatureException, InvalidKeyException { - testThreeMessages(Ed25519); - testThreeMessages(SECP256R1); - } - - @Test - public void testChangeMessages() throws NoSuchAlgorithmException, SignatureException, InvalidKeyException { - testChangeMessages(Ed25519); - testChangeMessages(SECP256R1); - } - - private static void testSerialize(Schema.PublicKey.Algorithm algorithm, int expectedPublicKeyLength) { - byte[] seed = {1, 2, 3, 4}; - SecureRandom rng = new SecureRandom(seed); - - KeyPair keypair = KeyPair.generate(algorithm, rng); - PublicKey pubkey = keypair.public_key(); - - byte[] serializedSecretKey = keypair.toBytes(); - byte[] serializedPublicKey = pubkey.toBytes(); - - KeyPair deserializedSecretKey = KeyPair.generate(algorithm, serializedSecretKey); - PublicKey deserializedPublicKey = new PublicKey(algorithm, serializedPublicKey); - - assertEquals(32, serializedSecretKey.length); - assertEquals(expectedPublicKeyLength, serializedPublicKey.length); - - System.out.println(keypair.toHex()); - System.out.println(deserializedSecretKey.toHex()); - assertArrayEquals(keypair.toBytes(), deserializedSecretKey.toBytes()); - - System.out.println(pubkey.toHex()); - System.out.println(deserializedPublicKey.toHex()); - assertEquals(pubkey.toHex(), deserializedPublicKey.toHex()); - } - - private static void testChangeMessages(Schema.PublicKey.Algorithm algorithm) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException { - byte[] seed = {0, 0, 0, 0}; - SecureRandom rng = new SecureRandom(seed); - - String message1 = "hello"; - KeyPair root = KeyPair.generate(algorithm, rng); - KeyPair keypair2 = KeyPair.generate(algorithm, rng); - Token token1 = new Token(root, message1.getBytes(), keypair2); - assertEquals(Right(null), token1.verify(new PublicKey(algorithm, root.public_key().key))); - - String message2 = "world"; - KeyPair keypair3 = KeyPair.generate(algorithm, rng); - Token token2 = token1.append(keypair3, message2.getBytes()); - token2.blocks.set(1, "you".getBytes()); - assertEquals(Left(new Error.FormatError.Signature.InvalidSignature("signature error: Verification equation was not satisfied")), - token2.verify(new PublicKey(algorithm, root.public_key().key))); - - String message3 = "!!"; - KeyPair keypair4 = KeyPair.generate(algorithm, rng); - Token token3 = token2.append(keypair4, message3.getBytes()); - assertEquals(Left(new Error.FormatError.Signature.InvalidSignature("signature error: Verification equation was not satisfied")), - token3.verify(new PublicKey(algorithm, root.public_key().key))); - } - - private static void testThreeMessages(Schema.PublicKey.Algorithm algorithm) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException { - byte[] seed = {0, 0, 0, 0}; - SecureRandom rng = new SecureRandom(seed); - - String message1 = "hello"; - KeyPair root = KeyPair.generate(algorithm, rng); - KeyPair keypair2 = KeyPair.generate(algorithm, rng); - System.out.println("root key: " + root.toHex()); - System.out.println("keypair2: " + keypair2.toHex()); - System.out.println("root key public: " + root.public_key().toHex()); - System.out.println("keypair2 public: " + keypair2.public_key().toHex()); - - Token token1 = new Token(root, message1.getBytes(), keypair2); - assertEquals(Right(null), token1.verify(root.public_key())); - - String message2 = "world"; - KeyPair keypair3 = KeyPair.generate(algorithm, rng); - Token token2 = token1.append(keypair3, message2.getBytes()); - assertEquals(Right(null), token2.verify(root.public_key())); - - String message3 = "!!"; - KeyPair keypair4 = KeyPair.generate(algorithm, rng); - Token token3 = token2.append(keypair4, message3.getBytes()); - assertEquals(Right(null), token3.verify(root.public_key())); - } - - @Test - public void testSerializeBiscuit() throws Error { - var root = KeyPair.generate(SECP256R1); - var biscuit = Biscuit.builder(root) - .add_authority_fact("user(\"1234\")") - .add_authority_check("check if operation(\"read\")") - .build(); - var serialized = biscuit.serialize(); - var unverified = Biscuit.from_bytes(serialized); - assertDoesNotThrow(() -> unverified.verify(root.public_key())); - } + @Test + public void testSerialize() { + prTestSerialize(Schema.PublicKey.Algorithm.Ed25519, 32); + prTestSerialize( + // compressed - 0x02 or 0x03 prefix byte, 32 bytes for X coordinate + Schema.PublicKey.Algorithm.SECP256R1, 33); + } + + @Test + public void testThreeMessages() + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException { + prTestThreeMessages(Schema.PublicKey.Algorithm.Ed25519); + prTestThreeMessages(Schema.PublicKey.Algorithm.SECP256R1); + } + + @Test + public void testSerializeBiscuit() throws Error { + var root = KeyPair.generate(Schema.PublicKey.Algorithm.SECP256R1); + var biscuit = + Biscuit.builder(root) + .addAuthorityFact("user(\"1234\")") + .addAuthorityCheck("check if operation(\"read\")") + .build(); + var serialized = biscuit.serialize(); + var unverified = Biscuit.fromBytes(serialized); + assertDoesNotThrow(() -> unverified.verify(root.getPublicKey())); + } + + private static void prTestSerialize( + Schema.PublicKey.Algorithm algorithm, int expectedPublicKeyLength) { + byte[] seed = {1, 2, 3, 4}; + SecureRandom rng = new SecureRandom(seed); + + KeyPair keypair = KeyPair.generate(algorithm, rng); + PublicKey pubkey = keypair.getPublicKey(); + + byte[] serializedSecretKey = keypair.toBytes(); + byte[] serializedPublicKey = pubkey.toBytes(); + + final KeyPair deserializedSecretKey = KeyPair.generate(algorithm, serializedSecretKey); + final PublicKey deserializedPublicKey = new PublicKey(algorithm, serializedPublicKey); + + assertEquals(32, serializedSecretKey.length); + assertEquals(expectedPublicKeyLength, serializedPublicKey.length); + + System.out.println(keypair.toHex()); + System.out.println(deserializedSecretKey.toHex()); + assertArrayEquals(keypair.toBytes(), deserializedSecretKey.toBytes()); + + System.out.println(pubkey.toHex()); + System.out.println(deserializedPublicKey.toHex()); + assertEquals(pubkey.toHex(), deserializedPublicKey.toHex()); + } + + private static void prTestThreeMessages(Schema.PublicKey.Algorithm algorithm) + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException { + byte[] seed = {0, 0, 0, 0}; + SecureRandom rng = new SecureRandom(seed); + + String message1 = "hello"; + KeyPair root = KeyPair.generate(algorithm, rng); + KeyPair keypair2 = KeyPair.generate(algorithm, rng); + System.out.println("root key: " + root.toHex()); + System.out.println("keypair2: " + keypair2.toHex()); + System.out.println("root key public: " + root.getPublicKey().toHex()); + System.out.println("keypair2 public: " + keypair2.getPublicKey().toHex()); + + Token token1 = new Token(root, message1.getBytes(), keypair2); + assertEquals(Right(null), token1.verify(root.getPublicKey())); + + String message2 = "world"; + KeyPair keypair3 = KeyPair.generate(algorithm, rng); + Token token2 = token1.append(keypair3, message2.getBytes()); + assertEquals(Right(null), token2.verify(root.getPublicKey())); + + String message3 = "!!"; + KeyPair keypair4 = KeyPair.generate(algorithm, rng); + Token token3 = token2.append(keypair4, message3.getBytes()); + assertEquals(Right(null), token3.verify(root.getPublicKey())); + } } diff --git a/src/test/java/org/biscuitsec/biscuit/datalog/ExpressionTest.java b/src/test/java/org/biscuitsec/biscuit/datalog/ExpressionTest.java index 5e38166b..e89e425a 100644 --- a/src/test/java/org/biscuitsec/biscuit/datalog/ExpressionTest.java +++ b/src/test/java/org/biscuitsec/biscuit/datalog/ExpressionTest.java @@ -1,141 +1,131 @@ package org.biscuitsec.biscuit.datalog; -import org.biscuitsec.biscuit.datalog.expressions.Expression; -import org.biscuitsec.biscuit.datalog.expressions.Op; -import org.biscuitsec.biscuit.error.Error; -import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; +import org.biscuitsec.biscuit.datalog.expressions.Expression; +import org.biscuitsec.biscuit.datalog.expressions.Op; +import org.biscuitsec.biscuit.error.Error; +import org.junit.jupiter.api.Test; public class ExpressionTest { - @Test - public void testNegate() throws Error.Execution { - SymbolTable symbols = new SymbolTable(); - symbols.add("a"); - symbols.add("b"); - symbols.add("var"); - - - Expression e = new Expression(new ArrayList(Arrays.asList( - new Op.Value(new Term.Integer(1)), - new Op.Value(new Term.Variable(SymbolTable.DEFAULT_SYMBOLS_OFFSET + 2)), - new Op.Binary(Op.BinaryOp.LessThan), - new Op.Unary(Op.UnaryOp.Negate) - ))); - - assertEquals( - "!1 < $var", - e.print(symbols).get() - ); - - HashMap variables = new HashMap<>(); - variables.put(SymbolTable.DEFAULT_SYMBOLS_OFFSET + 2L, new Term.Integer(0)); - - assertEquals( - new Term.Bool(true), - e.evaluate(variables, new TemporarySymbolTable(symbols)) - ); - } - - @Test - public void testAddsStr() throws Error.Execution { - SymbolTable symbols = new SymbolTable(); - symbols.add("a"); - symbols.add("b"); - symbols.add("ab"); - - - Expression e = new Expression(new ArrayList(Arrays.asList( - new Op.Value(new Term.Str(SymbolTable.DEFAULT_SYMBOLS_OFFSET)), - new Op.Value(new Term.Str(SymbolTable.DEFAULT_SYMBOLS_OFFSET + 1)), - new Op.Binary(Op.BinaryOp.Add) - ))); - - assertEquals( - "\"a\" + \"b\"", - e.print(symbols).get() - ); - - assertEquals( - new Term.Str(SymbolTable.DEFAULT_SYMBOLS_OFFSET + 2), - e.evaluate(new HashMap<>(), new TemporarySymbolTable(symbols)) - ); - } - - @Test - public void testContainsStr() throws Error.Execution { - SymbolTable symbols = new SymbolTable(); - symbols.add("ab"); - symbols.add("b"); - - - Expression e = new Expression(new ArrayList(Arrays.asList( - new Op.Value(new Term.Str(SymbolTable.DEFAULT_SYMBOLS_OFFSET)), - new Op.Value(new Term.Str(SymbolTable.DEFAULT_SYMBOLS_OFFSET + 1)), - new Op.Binary(Op.BinaryOp.Contains) - ))); - - assertEquals( - "\"ab\".contains(\"b\")", - e.print(symbols).get() - ); - - assertEquals( - new Term.Bool(true), - e.evaluate(new HashMap<>(), new TemporarySymbolTable(symbols)) - ); - } - - @Test - public void testNegativeContainsStr() throws Error.Execution { - SymbolTable symbols = new SymbolTable(); - symbols.add("ab"); - symbols.add("b"); - - - Expression e = new Expression(new ArrayList(Arrays.asList( - new Op.Value(new Term.Str(SymbolTable.DEFAULT_SYMBOLS_OFFSET)), - new Op.Value(new Term.Str(SymbolTable.DEFAULT_SYMBOLS_OFFSET + 1)), - new Op.Binary(Op.BinaryOp.Contains), - new Op.Unary(Op.UnaryOp.Negate) - ))); - - assertEquals( - "!\"ab\".contains(\"b\")", - e.print(symbols).get() - ); - - assertEquals( - new Term.Bool(false), - e.evaluate(new HashMap<>(), new TemporarySymbolTable(symbols)) - ); - } - - @Test - public void testIntersectionAndContains() throws Error.Execution { - SymbolTable symbols = new SymbolTable(); - - Expression e = new Expression(new ArrayList(Arrays.asList( - new Op.Value(new Term.Set(new HashSet<>(Arrays.asList(new Term.Integer(1), new Term.Integer(2), new Term.Integer(3))))), - new Op.Value(new Term.Set(new HashSet<>(Arrays.asList(new Term.Integer(1), new Term.Integer(2))))), - new Op.Binary(Op.BinaryOp.Intersection), - new Op.Value(new Term.Integer(1)), - new Op.Binary(Op.BinaryOp.Contains) - ))); - - assertEquals( - "[1, 2, 3].intersection([1, 2]).contains(1)", - e.print(symbols).get() - ); - - assertEquals( - new Term.Bool(true), - e.evaluate(new HashMap<>(), new TemporarySymbolTable(symbols)) - ); - } + @Test + public void testNegate() throws Error.Execution { + SymbolTable symbolTable = new SymbolTable(); + symbolTable.add("a"); + symbolTable.add("b"); + symbolTable.add("var"); + + Expression e = + new Expression( + new ArrayList( + Arrays.asList( + new Op.Value(new Term.Integer(1)), + new Op.Value(new Term.Variable(SymbolTable.DEFAULT_SYMBOLS_OFFSET + 2)), + new Op.Binary(Op.BinaryOp.LessThan), + new Op.Unary(Op.UnaryOp.Negate)))); + + assertEquals("!1 < $var", e.print(symbolTable).get()); + + HashMap variables = new HashMap<>(); + variables.put(SymbolTable.DEFAULT_SYMBOLS_OFFSET + 2L, new Term.Integer(0)); + + assertEquals(new Term.Bool(true), e.evaluate(variables, new TemporarySymbolTable(symbolTable))); + } + + @Test + public void testAddsStr() throws Error.Execution { + SymbolTable symbolTable = new SymbolTable(); + symbolTable.add("a"); + symbolTable.add("b"); + symbolTable.add("ab"); + + Expression e = + new Expression( + new ArrayList( + Arrays.asList( + new Op.Value(new Term.Str(SymbolTable.DEFAULT_SYMBOLS_OFFSET)), + new Op.Value(new Term.Str(SymbolTable.DEFAULT_SYMBOLS_OFFSET + 1)), + new Op.Binary(Op.BinaryOp.Add)))); + + assertEquals("\"a\" + \"b\"", e.print(symbolTable).get()); + + assertEquals( + new Term.Str(SymbolTable.DEFAULT_SYMBOLS_OFFSET + 2), + e.evaluate(new HashMap<>(), new TemporarySymbolTable(symbolTable))); + } + + @Test + public void testContainsStr() throws Error.Execution { + SymbolTable symbolTable = new SymbolTable(); + symbolTable.add("ab"); + symbolTable.add("b"); + + Expression e = + new Expression( + new ArrayList( + Arrays.asList( + new Op.Value(new Term.Str(SymbolTable.DEFAULT_SYMBOLS_OFFSET)), + new Op.Value(new Term.Str(SymbolTable.DEFAULT_SYMBOLS_OFFSET + 1)), + new Op.Binary(Op.BinaryOp.Contains)))); + + assertEquals("\"ab\".contains(\"b\")", e.print(symbolTable).get()); + + assertEquals( + new Term.Bool(true), e.evaluate(new HashMap<>(), new TemporarySymbolTable(symbolTable))); + } + + @Test + public void testNegativeContainsStr() throws Error.Execution { + SymbolTable symbolTable = new SymbolTable(); + symbolTable.add("ab"); + symbolTable.add("b"); + + Expression e = + new Expression( + new ArrayList( + Arrays.asList( + new Op.Value(new Term.Str(SymbolTable.DEFAULT_SYMBOLS_OFFSET)), + new Op.Value(new Term.Str(SymbolTable.DEFAULT_SYMBOLS_OFFSET + 1)), + new Op.Binary(Op.BinaryOp.Contains), + new Op.Unary(Op.UnaryOp.Negate)))); + + assertEquals("!\"ab\".contains(\"b\")", e.print(symbolTable).get()); + + assertEquals( + new Term.Bool(false), e.evaluate(new HashMap<>(), new TemporarySymbolTable(symbolTable))); + } + + @Test + public void testIntersectionAndContains() throws Error.Execution { + SymbolTable symbolTable = new SymbolTable(); + + Expression e = + new Expression( + new ArrayList( + Arrays.asList( + new Op.Value( + new Term.Set( + new HashSet<>( + Arrays.asList( + new Term.Integer(1), + new Term.Integer(2), + new Term.Integer(3))))), + new Op.Value( + new Term.Set( + new HashSet<>( + Arrays.asList(new Term.Integer(1), new Term.Integer(2))))), + new Op.Binary(Op.BinaryOp.Intersection), + new Op.Value(new Term.Integer(1)), + new Op.Binary(Op.BinaryOp.Contains)))); + + assertEquals("[1, 2, 3].intersection([1, 2]).contains(1)", e.print(symbolTable).get()); + + assertEquals( + new Term.Bool(true), e.evaluate(new HashMap<>(), new TemporarySymbolTable(symbolTable))); + } } diff --git a/src/test/java/org/biscuitsec/biscuit/datalog/WorldTest.java b/src/test/java/org/biscuitsec/biscuit/datalog/WorldTest.java index f0331d20..9f65c9f7 100644 --- a/src/test/java/org/biscuitsec/biscuit/datalog/WorldTest.java +++ b/src/test/java/org/biscuitsec/biscuit/datalog/WorldTest.java @@ -1,469 +1,766 @@ package org.biscuitsec.biscuit.datalog; -import org.biscuitsec.biscuit.datalog.expressions.Expression; -import org.biscuitsec.biscuit.datalog.expressions.Op; -import org.biscuitsec.biscuit.error.Error; -import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import java.time.Instant; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Iterator; import java.util.stream.Collectors; +import org.biscuitsec.biscuit.datalog.expressions.Expression; +import org.biscuitsec.biscuit.datalog.expressions.Op; +import org.biscuitsec.biscuit.error.Error; +import org.junit.jupiter.api.Test; public class WorldTest { - @Test - public void testFamily() throws Error { - final World w = new World(); - final SymbolTable syms = new SymbolTable(); - final Term a = syms.add("A"); - final Term b = syms.add("B"); - final Term c = syms.add("C"); - final Term d = syms.add("D"); - final Term e = syms.add("E"); - final long parent = syms.insert("parent"); - final long grandparent = syms.insert("grandparent"); - final long sibling = syms.insert("siblings"); - - w.add_fact(new Origin(0), new Fact(new Predicate(parent, Arrays.asList(a, b)))); - w.add_fact(new Origin(0), new Fact(new Predicate(parent, Arrays.asList(b, c)))); - w.add_fact(new Origin(0), new Fact(new Predicate(parent, Arrays.asList(c, d)))); - - final Rule r1 = new Rule(new Predicate(grandparent, - Arrays.asList(new Term.Variable(syms.insert("grandparent")), new Term.Variable(syms.insert("grandchild")))), Arrays.asList( - new Predicate(parent, Arrays.asList(new Term.Variable(syms.insert("grandparent")), new Term.Variable(syms.insert("parent")))), - new Predicate(parent, Arrays.asList(new Term.Variable(syms.insert("parent")), new Term.Variable(syms.insert("grandchild")))) - ), new ArrayList<>()); - - System.out.println("testing r1: " + syms.print_rule(r1)); - FactSet query_rule_result = w.query_rule(r1, (long)0, new TrustedOrigins(0), syms); - System.out.println("grandparents query_rules: [" + String.join(", ", query_rule_result.stream().map((f) -> syms.print_fact(f)).collect(Collectors.toList())) + "]"); - System.out.println("current facts: [" + String.join(", ", w.facts().stream().map((f) -> syms.print_fact(f)).collect(Collectors.toList())) + "]"); - - final Rule r2 = new Rule(new Predicate(grandparent, - Arrays.asList(new Term.Variable(syms.insert("grandparent")), new Term.Variable(syms.insert("grandchild")))), Arrays.asList( - new Predicate(parent, Arrays.asList(new Term.Variable(syms.insert("grandparent")), new Term.Variable(syms.insert("parent")))), - new Predicate(parent, Arrays.asList(new Term.Variable(syms.insert("parent")), new Term.Variable(syms.insert("grandchild")))) - ), new ArrayList<>()); - - System.out.println("adding r2: " + syms.print_rule(r2)); - w.add_rule((long)0, new TrustedOrigins(0), r2); - w.run(syms); - - System.out.println("parents:"); - final Rule query1 = new Rule(new Predicate(parent, - Arrays.asList(new Term.Variable(syms.insert("parent")), new Term.Variable(syms.insert("child")))), - Arrays.asList(new Predicate(parent, - Arrays.asList(new Term.Variable(syms.insert("parent")), new Term.Variable(syms.insert("child"))))), - new ArrayList<>()); - - for (Iterator it = w.query_rule(query1, (long) 0, new TrustedOrigins(0), syms).stream().iterator(); it.hasNext(); ) { - Fact fact = it.next(); - System.out.println("\t" + syms.print_fact(fact)); - } - final Rule query2 = new Rule(new Predicate(parent, Arrays.asList(new Term.Variable(syms.insert("parent")), b)), - Arrays.asList(new Predicate(parent, Arrays.asList(new Term.Variable(syms.insert("parent")), b))), - new ArrayList<>()); - System.out.println("parents of B: [" + String.join(", ", - w.query_rule(query2, (long) 0, new TrustedOrigins(0), syms) - .stream().map((f) -> syms.print_fact(f)).collect(Collectors.toSet())) + "]"); - final Rule query3 = new Rule(new Predicate(grandparent, Arrays.asList(new Term.Variable(syms.insert("grandparent")), - new Term.Variable(syms.insert("grandchild")))), - Arrays.asList(new Predicate(grandparent, Arrays.asList(new Term.Variable(syms.insert("grandparent")), - new Term.Variable(syms.insert("grandchild"))))), - new ArrayList<>()); - System.out.println("grandparents: [" + String.join(", ", - w.query_rule(query3, (long) 0, new TrustedOrigins(0), syms) - .stream().map((f) -> syms.print_fact(f)).collect(Collectors.toSet())) + "]"); - - w.add_fact(new Origin(0), new Fact(new Predicate(parent, Arrays.asList(c, e)))); - w.run(syms); - - final Rule query4 = new Rule(new Predicate(grandparent, - Arrays.asList(new Term.Variable(syms.insert("grandparent")), new Term.Variable(syms.insert("grandchild")))), - Arrays.asList(new Predicate(grandparent, - Arrays.asList(new Term.Variable(syms.insert("grandparent")), new Term.Variable(syms.insert("grandchild"))))), - new ArrayList<>()); - final FactSet res = w.query_rule(query4, (long) 0, new TrustedOrigins(0), syms); - System.out.println("grandparents after inserting parent(C, E): [" + String.join(", ", - res.stream().map((f) -> syms.print_fact(f)).collect(Collectors.toSet())) + "]"); - - final FactSet expected = new FactSet(new Origin(0), new HashSet<>(Arrays.asList( - new Fact(new Predicate(grandparent, Arrays.asList(a, c))), - new Fact(new Predicate(grandparent, Arrays.asList(b, d))), - new Fact(new Predicate(grandparent, Arrays.asList(b, e)))))); - assertEquals(expected, res); - - w.add_rule((long) 0, new TrustedOrigins(0), new Rule(new Predicate(sibling, - Arrays.asList(new Term.Variable(syms.insert("sibling1")), new Term.Variable(syms.insert("sibling2")))), Arrays.asList( - new Predicate(parent, Arrays.asList(new Term.Variable(syms.insert("parent")), new Term.Variable(syms.insert("sibling1")))), - new Predicate(parent, Arrays.asList(new Term.Variable(syms.insert("parent")), new Term.Variable(syms.insert("sibling2")))) - ), new ArrayList<>())); - w.run(syms); - - final Rule query5 = new Rule(new Predicate(sibling, Arrays.asList( - new Term.Variable(syms.insert("sibling1")), - new Term.Variable(syms.insert("sibling2")))), - Arrays.asList(new Predicate(sibling, Arrays.asList( - new Term.Variable(syms.insert("sibling1")), - new Term.Variable(syms.insert("sibling2"))))), - new ArrayList<>()); - System.out.println("siblings: [" + String.join(", ", - w.query_rule(query5, (long) 0, new TrustedOrigins(0), syms) - .stream().map((f) -> syms.print_fact(f)).collect(Collectors.toSet())) + "]"); - } - - @Test - public void testNumbers() throws Error { - final World w = new World(); - final SymbolTable syms = new SymbolTable(); - - final Term abc = syms.add("abc"); - final Term def = syms.add("def"); - final Term ghi = syms.add("ghi"); - final Term jkl = syms.add("jkl"); - final Term mno = syms.add("mno"); - final Term aaa = syms.add("AAA"); - final Term bbb = syms.add("BBB"); - final Term ccc = syms.add("CCC"); - final long t1 = syms.insert("t1"); - final long t2 = syms.insert("t2"); - final long join = syms.insert("join"); - - w.add_fact(new Origin(0), new Fact(new Predicate(t1, Arrays.asList(new Term.Integer(0), abc)))); - w.add_fact(new Origin(0), new Fact(new Predicate(t1, Arrays.asList(new Term.Integer(1), def)))); - w.add_fact(new Origin(0), new Fact(new Predicate(t1, Arrays.asList(new Term.Integer(2), ghi)))); - w.add_fact(new Origin(0), new Fact(new Predicate(t1, Arrays.asList(new Term.Integer(3), jkl)))); - w.add_fact(new Origin(0), new Fact(new Predicate(t1, Arrays.asList(new Term.Integer(4), mno)))); - - w.add_fact(new Origin(0), new Fact(new Predicate(t2, Arrays.asList(new Term.Integer(0), aaa, new Term.Integer(0))))); - w.add_fact(new Origin(0), new Fact(new Predicate(t2, Arrays.asList(new Term.Integer(1), bbb, new Term.Integer(0))))); - w.add_fact(new Origin(0), new Fact(new Predicate(t2, Arrays.asList(new Term.Integer(2), ccc, new Term.Integer(1))))); - - FactSet res = w.query_rule(new Rule(new Predicate(join, - Arrays.asList(new Term.Variable(syms.insert("left")), new Term.Variable(syms.insert("right"))) - ), - Arrays.asList(new Predicate(t1, Arrays.asList(new Term.Variable(syms.insert("id")), new Term.Variable(syms.insert("left")))), - new Predicate(t2, + @Test + public void testFamily() throws Error { + final World w = new World(); + final SymbolTable syms = new SymbolTable(); + final Term a = syms.add("A"); + final Term b = syms.add("B"); + final Term c = syms.add("C"); + final Term d = syms.add("D"); + final Term e = syms.add("E"); + final long parent = syms.insert("parent"); + final long grandparent = syms.insert("grandparent"); + final long sibling = syms.insert("siblings"); + + w.addFact(new Origin(0), new Fact(new Predicate(parent, Arrays.asList(a, b)))); + w.addFact(new Origin(0), new Fact(new Predicate(parent, Arrays.asList(b, c)))); + w.addFact(new Origin(0), new Fact(new Predicate(parent, Arrays.asList(c, d)))); + + final Rule r1 = + new Rule( + new Predicate( + grandparent, + Arrays.asList( + new Term.Variable(syms.insert("grandparent")), + new Term.Variable(syms.insert("grandchild")))), + Arrays.asList( + new Predicate( + parent, + Arrays.asList( + new Term.Variable(syms.insert("grandparent")), + new Term.Variable(syms.insert("parent")))), + new Predicate( + parent, + Arrays.asList( + new Term.Variable(syms.insert("parent")), + new Term.Variable(syms.insert("grandchild"))))), + new ArrayList<>()); + + System.out.println("testing r1: " + syms.formatRule(r1)); + FactSet queryRuleResult = w.queryRule(r1, (long) 0, new TrustedOrigins(0), syms); + System.out.println( + "grandparents query_rules: [" + + String.join( + ", ", + queryRuleResult.stream() + .map((f) -> syms.formatFact(f)) + .collect(Collectors.toList())) + + "]"); + System.out.println( + "current facts: [" + + String.join( + ", ", + w.getFacts().stream().map((f) -> syms.formatFact(f)).collect(Collectors.toList())) + + "]"); + + final Rule r2 = + new Rule( + new Predicate( + grandparent, + Arrays.asList( + new Term.Variable(syms.insert("grandparent")), + new Term.Variable(syms.insert("grandchild")))), + Arrays.asList( + new Predicate( + parent, + Arrays.asList( + new Term.Variable(syms.insert("grandparent")), + new Term.Variable(syms.insert("parent")))), + new Predicate( + parent, + Arrays.asList( + new Term.Variable(syms.insert("parent")), + new Term.Variable(syms.insert("grandchild"))))), + new ArrayList<>()); + + System.out.println("adding r2: " + syms.formatRule(r2)); + w.addRule((long) 0, new TrustedOrigins(0), r2); + w.run(syms); + + System.out.println("parents:"); + final Rule query1 = + new Rule( + new Predicate( + parent, + Arrays.asList( + new Term.Variable(syms.insert("parent")), + new Term.Variable(syms.insert("child")))), + Arrays.asList( + new Predicate( + parent, + Arrays.asList( + new Term.Variable(syms.insert("parent")), + new Term.Variable(syms.insert("child"))))), + new ArrayList<>()); + + for (Iterator it = + w.queryRule(query1, (long) 0, new TrustedOrigins(0), syms).stream().iterator(); + it.hasNext(); ) { + Fact fact = it.next(); + System.out.println("\t" + syms.formatFact(fact)); + } + final Rule query2 = + new Rule( + new Predicate(parent, Arrays.asList(new Term.Variable(syms.insert("parent")), b)), + Arrays.asList( + new Predicate(parent, Arrays.asList(new Term.Variable(syms.insert("parent")), b))), + new ArrayList<>()); + System.out.println( + "parents of B: [" + + String.join( + ", ", + w.queryRule(query2, (long) 0, new TrustedOrigins(0), syms).stream() + .map((f) -> syms.formatFact(f)) + .collect(Collectors.toSet())) + + "]"); + final Rule query3 = + new Rule( + new Predicate( + grandparent, + Arrays.asList( + new Term.Variable(syms.insert("grandparent")), + new Term.Variable(syms.insert("grandchild")))), + Arrays.asList( + new Predicate( + grandparent, + Arrays.asList( + new Term.Variable(syms.insert("grandparent")), + new Term.Variable(syms.insert("grandchild"))))), + new ArrayList<>()); + System.out.println( + "grandparents: [" + + String.join( + ", ", + w.queryRule(query3, (long) 0, new TrustedOrigins(0), syms).stream() + .map((f) -> syms.formatFact(f)) + .collect(Collectors.toSet())) + + "]"); + + w.addFact(new Origin(0), new Fact(new Predicate(parent, Arrays.asList(c, e)))); + w.run(syms); + + final Rule query4 = + new Rule( + new Predicate( + grandparent, + Arrays.asList( + new Term.Variable(syms.insert("grandparent")), + new Term.Variable(syms.insert("grandchild")))), + Arrays.asList( + new Predicate( + grandparent, + Arrays.asList( + new Term.Variable(syms.insert("grandparent")), + new Term.Variable(syms.insert("grandchild"))))), + new ArrayList<>()); + final FactSet res = w.queryRule(query4, (long) 0, new TrustedOrigins(0), syms); + System.out.println( + "grandparents after inserting parent(C, E): [" + + String.join( + ", ", res.stream().map((f) -> syms.formatFact(f)).collect(Collectors.toSet())) + + "]"); + + final FactSet expected = + new FactSet( + new Origin(0), + new HashSet<>( + Arrays.asList( + new Fact(new Predicate(grandparent, Arrays.asList(a, c))), + new Fact(new Predicate(grandparent, Arrays.asList(b, d))), + new Fact(new Predicate(grandparent, Arrays.asList(b, e)))))); + assertEquals(expected, res); + + w.addRule( + (long) 0, + new TrustedOrigins(0), + new Rule( + new Predicate( + sibling, + Arrays.asList( + new Term.Variable(syms.insert("sibling1")), + new Term.Variable(syms.insert("sibling2")))), + Arrays.asList( + new Predicate( + parent, + Arrays.asList( + new Term.Variable(syms.insert("parent")), + new Term.Variable(syms.insert("sibling1")))), + new Predicate( + parent, + Arrays.asList( + new Term.Variable(syms.insert("parent")), + new Term.Variable(syms.insert("sibling2"))))), + new ArrayList<>())); + w.run(syms); + + final Rule query5 = + new Rule( + new Predicate( + sibling, + Arrays.asList( + new Term.Variable(syms.insert("sibling1")), + new Term.Variable(syms.insert("sibling2")))), + Arrays.asList( + new Predicate( + sibling, + Arrays.asList( + new Term.Variable(syms.insert("sibling1")), + new Term.Variable(syms.insert("sibling2"))))), + new ArrayList<>()); + System.out.println( + "siblings: [" + + String.join( + ", ", + w.queryRule(query5, (long) 0, new TrustedOrigins(0), syms).stream() + .map((f) -> syms.formatFact(f)) + .collect(Collectors.toSet())) + + "]"); + } + + @Test + public void testNumbers() throws Error { + final World w = new World(); + final SymbolTable syms = new SymbolTable(); + + final Term abc = syms.add("abc"); + final Term def = syms.add("def"); + final Term ghi = syms.add("ghi"); + final Term jkl = syms.add("jkl"); + final Term mno = syms.add("mno"); + final Term aaa = syms.add("AAA"); + final Term bbb = syms.add("BBB"); + final Term ccc = syms.add("CCC"); + final long t1 = syms.insert("t1"); + final long t2 = syms.insert("t2"); + final long join = syms.insert("join"); + + w.addFact(new Origin(0), new Fact(new Predicate(t1, Arrays.asList(new Term.Integer(0), abc)))); + w.addFact(new Origin(0), new Fact(new Predicate(t1, Arrays.asList(new Term.Integer(1), def)))); + w.addFact(new Origin(0), new Fact(new Predicate(t1, Arrays.asList(new Term.Integer(2), ghi)))); + w.addFact(new Origin(0), new Fact(new Predicate(t1, Arrays.asList(new Term.Integer(3), jkl)))); + w.addFact(new Origin(0), new Fact(new Predicate(t1, Arrays.asList(new Term.Integer(4), mno)))); + + w.addFact( + new Origin(0), + new Fact(new Predicate(t2, Arrays.asList(new Term.Integer(0), aaa, new Term.Integer(0))))); + w.addFact( + new Origin(0), + new Fact(new Predicate(t2, Arrays.asList(new Term.Integer(1), bbb, new Term.Integer(0))))); + w.addFact( + new Origin(0), + new Fact(new Predicate(t2, Arrays.asList(new Term.Integer(2), ccc, new Term.Integer(1))))); + + FactSet res = + w.queryRule( + new Rule( + new Predicate( + join, + Arrays.asList( + new Term.Variable(syms.insert("left")), + new Term.Variable(syms.insert("right")))), + Arrays.asList( + new Predicate( + t1, + Arrays.asList( + new Term.Variable(syms.insert("id")), + new Term.Variable(syms.insert("left")))), + new Predicate( + t2, + Arrays.asList( + new Term.Variable(syms.insert("t2_id")), + new Term.Variable(syms.insert("right")), + new Term.Variable(syms.insert("id"))))), + new ArrayList<>()), + (long) 0, + new TrustedOrigins(0), + syms); + for (Iterator it = res.stream().iterator(); it.hasNext(); ) { + Fact f = it.next(); + System.out.println("\t" + syms.formatFact(f)); + } + FactSet expected = + new FactSet( + new Origin(0), + new HashSet<>( + Arrays.asList( + new Fact(new Predicate(join, Arrays.asList(abc, aaa))), + new Fact(new Predicate(join, Arrays.asList(abc, bbb))), + new Fact(new Predicate(join, Arrays.asList(def, ccc)))))); + assertEquals(expected, res); + + res = + w.queryRule( + new Rule( + new Predicate( + join, + Arrays.asList( + new Term.Variable(syms.insert("left")), + new Term.Variable(syms.insert("right")))), + Arrays.asList( + new Predicate( + t1, + Arrays.asList( + new Term.Variable(syms.insert("id")), + new Term.Variable(syms.insert("left")))), + new Predicate( + t2, + Arrays.asList( + new Term.Variable(syms.insert("t2_id")), + new Term.Variable(syms.insert("right")), + new Term.Variable(syms.insert("id"))))), + Arrays.asList( + new Expression( + new ArrayList( Arrays.asList( - new Term.Variable(syms.insert("t2_id")), - new Term.Variable(syms.insert("right")), - new Term.Variable(syms.insert("id"))))), new ArrayList<>()), - (long) 0, new TrustedOrigins(0), syms); - for (Iterator it = res.stream().iterator(); it.hasNext(); ) { - Fact f = it.next(); - System.out.println("\t" + syms.print_fact(f)); - } - FactSet expected = new FactSet(new Origin(0),new HashSet<>(Arrays.asList(new Fact(new Predicate(join, Arrays.asList(abc, aaa))), - new Fact(new Predicate(join, Arrays.asList(abc, bbb))), - new Fact(new Predicate(join, Arrays.asList(def, ccc)))))); - assertEquals(expected, res); - - res = w.query_rule(new Rule(new Predicate(join, - Arrays.asList(new Term.Variable(syms.insert("left")), new Term.Variable(syms.insert("right")))), - Arrays.asList(new Predicate(t1, Arrays.asList(new Term.Variable(syms.insert("id")), new Term.Variable(syms.insert("left")))), - new Predicate(t2, - Arrays.asList( - new Term.Variable(syms.insert("t2_id")), - new Term.Variable(syms.insert("right")), - new Term.Variable(syms.insert("id"))))), - Arrays.asList(new Expression(new ArrayList(Arrays.asList( - new Op.Value(new Term.Variable(syms.insert("id"))), - new Op.Value(new Term.Integer(1)), - new Op.Binary(Op.BinaryOp.LessThan) - )))) - ), (long) 0, new TrustedOrigins(0), syms); - for (Iterator it = res.stream().iterator(); it.hasNext(); ) { - Fact f = it.next(); - System.out.println("\t" + syms.print_fact(f)); - } - expected = new FactSet(new Origin(0), - new HashSet<>(Arrays.asList(new Fact(new Predicate(join, Arrays.asList(abc, aaa))), new Fact(new Predicate(join, Arrays.asList(abc, bbb)))))); - assertEquals(expected, res); - } - - private final FactSet testSuffix(final World w, SymbolTable syms, final long suff, final long route, final String suffix) throws Error { - return w.query_rule(new Rule(new Predicate(suff, - Arrays.asList(new Term.Variable(syms.insert("app_id")), new Term.Variable(syms.insert("domain")))), - Arrays.asList( - new Predicate(route, Arrays.asList( - new Term.Variable(syms.insert("route_id")), + new Op.Value(new Term.Variable(syms.insert("id"))), + new Op.Value(new Term.Integer(1)), + new Op.Binary(Op.BinaryOp.LessThan)))))), + (long) 0, + new TrustedOrigins(0), + syms); + for (Iterator it = res.stream().iterator(); it.hasNext(); ) { + Fact f = it.next(); + System.out.println("\t" + syms.formatFact(f)); + } + expected = + new FactSet( + new Origin(0), + new HashSet<>( + Arrays.asList( + new Fact(new Predicate(join, Arrays.asList(abc, aaa))), + new Fact(new Predicate(join, Arrays.asList(abc, bbb)))))); + assertEquals(expected, res); + } + + private final FactSet testSuffix( + final World w, SymbolTable syms, final long suff, final long route, final String suffix) + throws Error { + return w.queryRule( + new Rule( + new Predicate( + suff, + Arrays.asList( new Term.Variable(syms.insert("app_id")), - new Term.Variable(syms.insert("domain")))) - ), - Arrays.asList(new Expression(new ArrayList(Arrays.asList( - new Op.Value(new Term.Variable(syms.insert("domain"))), - new Op.Value(syms.add(suffix)), - new Op.Binary(Op.BinaryOp.Suffix) - )))) - ), (long) 0, new TrustedOrigins(0), syms); - } - - @Test - public void testStr() throws Error { - final World w = new World(); - final SymbolTable syms = new SymbolTable(); - - final Term app_0 = syms.add("app_0"); - final Term app_1 = syms.add("app_1"); - final Term app_2 = syms.add("app_2"); - final long route = syms.insert("route"); - final long suff = syms.insert("route suffix"); - - w.add_fact(new Origin(0), new Fact(new Predicate(route, Arrays.asList(new Term.Integer(0), app_0, syms.add("example.com"))))); - w.add_fact(new Origin(0), new Fact(new Predicate(route, Arrays.asList(new Term.Integer(1), app_1, syms.add("test.com"))))); - w.add_fact(new Origin(0), new Fact(new Predicate(route, Arrays.asList(new Term.Integer(2), app_2, syms.add("test.fr"))))); - w.add_fact(new Origin(0), new Fact(new Predicate(route, Arrays.asList(new Term.Integer(3), app_0, syms.add("www.example.com"))))); - w.add_fact(new Origin(0), new Fact(new Predicate(route, Arrays.asList(new Term.Integer(4), app_1, syms.add("mx.example.com"))))); - - FactSet res = testSuffix(w, syms, suff, route, ".fr"); - for (Iterator it = res.stream().iterator(); it.hasNext(); ) { - Fact f = it.next(); - System.out.println("\t" + syms.print_fact(f)); - } - FactSet expected = new FactSet(new Origin(0), - new HashSet<>(Arrays.asList(new Fact(new Predicate(suff, Arrays.asList(app_2, syms.add("test.fr"))))))); - assertEquals(expected, res); - - res = testSuffix(w, syms, suff, route, "example.com"); - for (Iterator it = res.stream().iterator(); it.hasNext(); ) { - Fact f = it.next(); - System.out.println("\t" + syms.print_fact(f)); - } - expected = new FactSet(new Origin(0),new HashSet<>(Arrays.asList(new Fact(new Predicate(suff, - Arrays.asList( - app_0, - syms.add("example.com")))), - new Fact(new Predicate(suff, - Arrays.asList(app_0, syms.add("www.example.com")))), - new Fact(new Predicate(suff, Arrays.asList(app_1, syms.add("mx.example.com"))))))); - assertEquals(expected, res); - } - - @Test - public void testDate() throws Error { - final World w = new World(); - final SymbolTable syms = new SymbolTable(); - - final Instant t1 = Instant.now(); - System.out.println("t1 = " + t1); - final Instant t2 = t1.plusSeconds(10); - System.out.println("t2 = " + t2); - final Instant t3 = t2.plusSeconds(30); - System.out.println("t3 = " + t3); - - final long t2_timestamp = t2.getEpochSecond(); - - final Term abc = syms.add("abc"); - final Term def = syms.add("def"); - final long x = syms.insert("x"); - final long before = syms.insert("before"); - final long after = syms.insert("after"); - - w.add_fact(new Origin(0), new Fact(new Predicate(x, Arrays.asList(new Term.Date(t1.getEpochSecond()), abc)))); - w.add_fact(new Origin(0), new Fact(new Predicate(x, Arrays.asList(new Term.Date(t3.getEpochSecond()), def)))); - - final Rule r1 = new Rule(new Predicate( - before, - Arrays.asList(new Term.Variable(syms.insert("date")), new Term.Variable(syms.insert("val")))), - Arrays.asList( - new Predicate(x, Arrays.asList(new Term.Variable(syms.insert("date")), new Term.Variable(syms.insert("val")))) - ), - Arrays.asList( - new Expression(new ArrayList(Arrays.asList( - new Op.Value(new Term.Variable(syms.insert("date"))), - new Op.Value(new Term.Date(t2_timestamp)), - new Op.Binary(Op.BinaryOp.LessOrEqual) - ))), - new Expression(new ArrayList(Arrays.asList( - new Op.Value(new Term.Variable(syms.insert("date"))), - new Op.Value(new Term.Date(0)), - new Op.Binary(Op.BinaryOp.GreaterOrEqual) - ))) - ) - ); - - System.out.println("testing r1: " + syms.print_rule(r1)); - FactSet res = w.query_rule(r1, (long) 0, new TrustedOrigins(0), syms); - for (Iterator it = res.stream().iterator(); it.hasNext(); ) { - Fact f = it.next(); - System.out.println("\t" + syms.print_fact(f)); - } - FactSet expected = new FactSet(new Origin(0),new HashSet<>(Arrays.asList(new Fact(new Predicate(before, Arrays.asList(new Term.Date(t1.getEpochSecond()), abc)))))); - assertEquals(expected, res); - - final Rule r2 = new Rule(new Predicate( - after, - Arrays.asList(new Term.Variable(syms.insert("date")), new Term.Variable(syms.insert("val")))), - Arrays.asList( - new Predicate(x, Arrays.asList(new Term.Variable(syms.insert("date")), new Term.Variable(syms.insert("val")))) - ), - Arrays.asList( - new Expression(new ArrayList(Arrays.asList( - new Op.Value(new Term.Variable(syms.insert("date"))), - new Op.Value(new Term.Date(t2_timestamp)), - new Op.Binary(Op.BinaryOp.GreaterOrEqual) - ))), - new Expression(new ArrayList(Arrays.asList( - new Op.Value(new Term.Variable(syms.insert("date"))), - new Op.Value(new Term.Date(0)), - new Op.Binary(Op.BinaryOp.GreaterOrEqual) - ))) - ) - ); - - System.out.println("testing r2: " + syms.print_rule(r2)); - res = w.query_rule(r2, (long) 0, new TrustedOrigins(0), syms); - for (Iterator it = res.stream().iterator(); it.hasNext(); ) { - Fact f = it.next(); - System.out.println("\t" + syms.print_fact(f)); - } - expected = new FactSet(new Origin(0),new HashSet<>(Arrays.asList(new Fact(new Predicate(after, Arrays.asList(new Term.Date(t3.getEpochSecond()), def)))))); - assertEquals(expected, res); - } - - @Test - public void testSet() throws Error { - final World w = new World(); - final SymbolTable syms = new SymbolTable(); - - final Term abc = syms.add("abc"); - final Term def = syms.add("def"); - final long x = syms.insert("x"); - final long int_set = syms.insert("int_set"); - final long symbol_set = syms.insert("symbol_set"); - final long string_set = syms.insert("string_set"); - - w.add_fact(new Origin(0), new Fact(new Predicate(x, Arrays.asList(abc, new Term.Integer(0), syms.add("test"))))); - w.add_fact(new Origin(0), new Fact(new Predicate(x, Arrays.asList(def, new Term.Integer(2), syms.add("hello"))))); - - final Rule r1 = new Rule(new Predicate( - int_set, - Arrays.asList(new Term.Variable(syms.insert("sym")), new Term.Variable(syms.insert("str"))) - ), - Arrays.asList(new Predicate(x, - Arrays.asList(new Term.Variable(syms.insert("sym")), new Term.Variable(syms.insert("int")), new Term.Variable(syms.insert("str")))) - ), - Arrays.asList( - new Expression(new ArrayList(Arrays.asList( - new Op.Value(new Term.Set(new HashSet<>(Arrays.asList(new Term.Integer(0l), new Term.Integer(1l))))), - new Op.Value(new Term.Variable(syms.insert("int"))), - new Op.Binary(Op.BinaryOp.Contains) - ))) - ) - ); - System.out.println("testing r1: " + syms.print_rule(r1)); - FactSet res = w.query_rule(r1, (long) 0, new TrustedOrigins(0), syms); - for (Iterator it = res.stream().iterator(); it.hasNext(); ) { - Fact f = it.next(); - System.out.println("\t" + syms.print_fact(f)); - } - FactSet expected = new FactSet(new Origin(0), new HashSet<>(Arrays.asList(new Fact(new Predicate(int_set, Arrays.asList(abc, syms.add("test"))))))); - assertEquals(expected, res); - - final long abc_sym_id = syms.insert("abc"); - final long ghi_sym_id = syms.insert("ghi"); - - final Rule r2 = new Rule(new Predicate(symbol_set, - Arrays.asList(new Term.Variable(syms.insert("sym")), new Term.Variable(syms.insert("int")), new Term.Variable(syms.insert("str")))), - Arrays.asList(new Predicate(x, Arrays.asList(new Term.Variable(syms.insert("sym")), new Term.Variable(syms.insert("int")), new Term.Variable(syms.insert("str")))) - ), - Arrays.asList( - new Expression(new ArrayList(Arrays.asList( - new Op.Value(new Term.Set(new HashSet<>(Arrays.asList(new Term.Str(abc_sym_id), new Term.Str(ghi_sym_id))))), - new Op.Value(new Term.Variable(syms.insert("sym"))), - new Op.Binary(Op.BinaryOp.Contains), - new Op.Unary(Op.UnaryOp.Negate) - ))) - ) - ); - - System.out.println("testing r2: " + syms.print_rule(r2)); - res = w.query_rule(r2, (long) 0, new TrustedOrigins(0), syms); - for (Iterator it = res.stream().iterator(); it.hasNext(); ) { - Fact f = it.next(); - System.out.println("\t" + syms.print_fact(f)); - } - expected = new FactSet(new Origin(0),new HashSet<>(Arrays.asList(new Fact(new Predicate(symbol_set, Arrays.asList(def, new Term.Integer(2), syms.add("hello"))))))); - assertEquals(expected, res); - - final Rule r3 = new Rule( - new Predicate(string_set, Arrays.asList(new Term.Variable(syms.insert("sym")), new Term.Variable(syms.insert("int")), new Term.Variable(syms.insert("str")))), - Arrays.asList(new Predicate(x, Arrays.asList(new Term.Variable(syms.insert("sym")), new Term.Variable(syms.insert("int")), new Term.Variable(syms.insert("str"))))), - Arrays.asList( - new Expression(new ArrayList(Arrays.asList( - new Op.Value(new Term.Set(new HashSet<>(Arrays.asList(syms.add("test"), syms.add("aaa"))))), - new Op.Value(new Term.Variable(syms.insert("str"))), - new Op.Binary(Op.BinaryOp.Contains) - ))) - ) - ); - System.out.println("testing r3: " + syms.print_rule(r3)); - res = w.query_rule(r3, (long) 0, new TrustedOrigins(0), syms); - for (Iterator it = res.stream().iterator(); it.hasNext(); ) { - Fact f = it.next(); - System.out.println("\t" + syms.print_fact(f)); - } - expected = new FactSet(new Origin(0),new HashSet<>(Arrays.asList(new Fact(new Predicate(string_set, Arrays.asList(abc, new Term.Integer(0), syms.add("test"))))))); - assertEquals(expected, res); - } - - @Test - public void testResource() throws Error { - final World w = new World(); - final SymbolTable syms = new SymbolTable(); - - final Term authority = syms.add("authority"); - final Term ambient = syms.add("ambient"); - final long resource = syms.insert("resource"); - final long operation = syms.insert("operation"); - final long right = syms.insert("right"); - final Term file1 = syms.add("file1"); - final Term file2 = syms.add("file2"); - final Term read = syms.add("read"); - final Term write = syms.add("write"); - - - w.add_fact(new Origin(0), new Fact(new Predicate(right, Arrays.asList(file1, read)))); - w.add_fact(new Origin(0), new Fact(new Predicate(right, Arrays.asList(file2, read)))); - w.add_fact(new Origin(0), new Fact(new Predicate(right, Arrays.asList(file1, write)))); - - final long caveat1 = syms.insert("caveat1"); - //r1: caveat2(#file1) <- resource(#ambient, #file1) - final Rule r1 = new Rule( - new Predicate(caveat1, Arrays.asList(file1)), - Arrays.asList(new Predicate(resource, Arrays.asList(file1)) - ), new ArrayList<>()); - - System.out.println("testing caveat 1(should return nothing): " + syms.print_rule(r1)); - FactSet res = w.query_rule(r1, (long) 0, new TrustedOrigins(0), syms); - System.out.println(res); - for (Iterator it = res.stream().iterator(); it.hasNext(); ) { - Fact f = it.next(); - System.out.println("\t" + syms.print_fact(f)); - } - assertTrue(res.size() == 0); - - final long caveat2 = syms.insert("caveat2"); - final long var0_id = syms.insert("var0"); - final Term var0 = new Term.Variable(var0_id); - //r2: caveat1(0?) <- resource(#ambient, 0?) && operation(#ambient, #read) && right(#authority, 0?, #read) - final Rule r2 = new Rule( - new Predicate(caveat2, Arrays.asList(var0)), - Arrays.asList( - new Predicate(resource, Arrays.asList(var0)), - new Predicate(operation, Arrays.asList(read)), - new Predicate(right, Arrays.asList(var0, read)) - ), new ArrayList<>()); - - System.out.println("testing caveat 2: " + syms.print_rule(r2)); - res = w.query_rule(r2, (long) 0, new TrustedOrigins(0), syms); - System.out.println(res); - for (Iterator it = res.stream().iterator(); it.hasNext(); ) { - Fact f = it.next(); - System.out.println("\t" + syms.print_fact(f)); - } - assertTrue(res.size() == 0); - } + new Term.Variable(syms.insert("domain")))), + Arrays.asList( + new Predicate( + route, + Arrays.asList( + new Term.Variable(syms.insert("route_id")), + new Term.Variable(syms.insert("app_id")), + new Term.Variable(syms.insert("domain"))))), + Arrays.asList( + new Expression( + new ArrayList( + Arrays.asList( + new Op.Value(new Term.Variable(syms.insert("domain"))), + new Op.Value(syms.add(suffix)), + new Op.Binary(Op.BinaryOp.Suffix)))))), + (long) 0, + new TrustedOrigins(0), + syms); + } + + @Test + public void testStr() throws Error { + final World w = new World(); + final SymbolTable syms = new SymbolTable(); + + final Term app_0 = syms.add("app_0"); + final Term app_1 = syms.add("app_1"); + final Term app_2 = syms.add("app_2"); + final long route = syms.insert("route"); + final long suff = syms.insert("route suffix"); + + w.addFact( + new Origin(0), + new Fact( + new Predicate( + route, Arrays.asList(new Term.Integer(0), app_0, syms.add("example.com"))))); + w.addFact( + new Origin(0), + new Fact( + new Predicate(route, Arrays.asList(new Term.Integer(1), app_1, syms.add("test.com"))))); + w.addFact( + new Origin(0), + new Fact( + new Predicate(route, Arrays.asList(new Term.Integer(2), app_2, syms.add("test.fr"))))); + w.addFact( + new Origin(0), + new Fact( + new Predicate( + route, Arrays.asList(new Term.Integer(3), app_0, syms.add("www.example.com"))))); + w.addFact( + new Origin(0), + new Fact( + new Predicate( + route, Arrays.asList(new Term.Integer(4), app_1, syms.add("mx.example.com"))))); + + FactSet res = testSuffix(w, syms, suff, route, ".fr"); + for (Iterator it = res.stream().iterator(); it.hasNext(); ) { + Fact f = it.next(); + System.out.println("\t" + syms.formatFact(f)); + } + FactSet expected = + new FactSet( + new Origin(0), + new HashSet<>( + Arrays.asList( + new Fact(new Predicate(suff, Arrays.asList(app_2, syms.add("test.fr"))))))); + assertEquals(expected, res); + + res = testSuffix(w, syms, suff, route, "example.com"); + for (Iterator it = res.stream().iterator(); it.hasNext(); ) { + Fact f = it.next(); + System.out.println("\t" + syms.formatFact(f)); + } + expected = + new FactSet( + new Origin(0), + new HashSet<>( + Arrays.asList( + new Fact(new Predicate(suff, Arrays.asList(app_0, syms.add("example.com")))), + new Fact( + new Predicate(suff, Arrays.asList(app_0, syms.add("www.example.com")))), + new Fact( + new Predicate(suff, Arrays.asList(app_1, syms.add("mx.example.com"))))))); + assertEquals(expected, res); + } + + @Test + public void testDate() throws Error { + final World w = new World(); + final SymbolTable syms = new SymbolTable(); + + final Instant t1 = Instant.now(); + System.out.println("t1 = " + t1); + final Instant t2 = t1.plusSeconds(10); + System.out.println("t2 = " + t2); + final Instant t3 = t2.plusSeconds(30); + System.out.println("t3 = " + t3); + + final long t2_timestamp = t2.getEpochSecond(); + + final Term abc = syms.add("abc"); + final Term def = syms.add("def"); + final long x = syms.insert("x"); + final long before = syms.insert("before"); + final long after = syms.insert("after"); + + w.addFact( + new Origin(0), + new Fact(new Predicate(x, Arrays.asList(new Term.Date(t1.getEpochSecond()), abc)))); + w.addFact( + new Origin(0), + new Fact(new Predicate(x, Arrays.asList(new Term.Date(t3.getEpochSecond()), def)))); + + final Rule r1 = + new Rule( + new Predicate( + before, + Arrays.asList( + new Term.Variable(syms.insert("date")), new Term.Variable(syms.insert("val")))), + Arrays.asList( + new Predicate( + x, + Arrays.asList( + new Term.Variable(syms.insert("date")), + new Term.Variable(syms.insert("val"))))), + Arrays.asList( + new Expression( + new ArrayList( + Arrays.asList( + new Op.Value(new Term.Variable(syms.insert("date"))), + new Op.Value(new Term.Date(t2_timestamp)), + new Op.Binary(Op.BinaryOp.LessOrEqual)))), + new Expression( + new ArrayList( + Arrays.asList( + new Op.Value(new Term.Variable(syms.insert("date"))), + new Op.Value(new Term.Date(0)), + new Op.Binary(Op.BinaryOp.GreaterOrEqual)))))); + + System.out.println("testing r1: " + syms.formatRule(r1)); + FactSet res = w.queryRule(r1, (long) 0, new TrustedOrigins(0), syms); + for (Iterator it = res.stream().iterator(); it.hasNext(); ) { + Fact f = it.next(); + System.out.println("\t" + syms.formatFact(f)); + } + FactSet expected = + new FactSet( + new Origin(0), + new HashSet<>( + Arrays.asList( + new Fact( + new Predicate( + before, Arrays.asList(new Term.Date(t1.getEpochSecond()), abc)))))); + assertEquals(expected, res); + + final Rule r2 = + new Rule( + new Predicate( + after, + Arrays.asList( + new Term.Variable(syms.insert("date")), new Term.Variable(syms.insert("val")))), + Arrays.asList( + new Predicate( + x, + Arrays.asList( + new Term.Variable(syms.insert("date")), + new Term.Variable(syms.insert("val"))))), + Arrays.asList( + new Expression( + new ArrayList( + Arrays.asList( + new Op.Value(new Term.Variable(syms.insert("date"))), + new Op.Value(new Term.Date(t2_timestamp)), + new Op.Binary(Op.BinaryOp.GreaterOrEqual)))), + new Expression( + new ArrayList( + Arrays.asList( + new Op.Value(new Term.Variable(syms.insert("date"))), + new Op.Value(new Term.Date(0)), + new Op.Binary(Op.BinaryOp.GreaterOrEqual)))))); + + System.out.println("testing r2: " + syms.formatRule(r2)); + res = w.queryRule(r2, (long) 0, new TrustedOrigins(0), syms); + for (Iterator it = res.stream().iterator(); it.hasNext(); ) { + Fact f = it.next(); + System.out.println("\t" + syms.formatFact(f)); + } + expected = + new FactSet( + new Origin(0), + new HashSet<>( + Arrays.asList( + new Fact( + new Predicate( + after, Arrays.asList(new Term.Date(t3.getEpochSecond()), def)))))); + assertEquals(expected, res); + } + + @Test + public void testSet() throws Error { + final World w = new World(); + final SymbolTable syms = new SymbolTable(); + + final Term abc = syms.add("abc"); + final Term def = syms.add("def"); + final long x = syms.insert("x"); + final long int_set = syms.insert("int_set"); + final long symbol_set = syms.insert("symbol_set"); + final long string_set = syms.insert("string_set"); + + w.addFact( + new Origin(0), + new Fact(new Predicate(x, Arrays.asList(abc, new Term.Integer(0), syms.add("test"))))); + w.addFact( + new Origin(0), + new Fact(new Predicate(x, Arrays.asList(def, new Term.Integer(2), syms.add("hello"))))); + + final Rule r1 = + new Rule( + new Predicate( + int_set, + Arrays.asList( + new Term.Variable(syms.insert("sym")), new Term.Variable(syms.insert("str")))), + Arrays.asList( + new Predicate( + x, + Arrays.asList( + new Term.Variable(syms.insert("sym")), + new Term.Variable(syms.insert("int")), + new Term.Variable(syms.insert("str"))))), + Arrays.asList( + new Expression( + new ArrayList( + Arrays.asList( + new Op.Value( + new Term.Set( + new HashSet<>( + Arrays.asList( + new Term.Integer(0L), new Term.Integer(1L))))), + new Op.Value(new Term.Variable(syms.insert("int"))), + new Op.Binary(Op.BinaryOp.Contains)))))); + System.out.println("testing r1: " + syms.formatRule(r1)); + FactSet res = w.queryRule(r1, (long) 0, new TrustedOrigins(0), syms); + for (Iterator it = res.stream().iterator(); it.hasNext(); ) { + Fact f = it.next(); + System.out.println("\t" + syms.formatFact(f)); + } + FactSet expected = + new FactSet( + new Origin(0), + new HashSet<>( + Arrays.asList( + new Fact(new Predicate(int_set, Arrays.asList(abc, syms.add("test"))))))); + assertEquals(expected, res); + + final long abcSymId = syms.insert("abc"); + final long ghiSymId = syms.insert("ghi"); + + final Rule r2 = + new Rule( + new Predicate( + symbol_set, + Arrays.asList( + new Term.Variable(syms.insert("sym")), + new Term.Variable(syms.insert("int")), + new Term.Variable(syms.insert("str")))), + Arrays.asList( + new Predicate( + x, + Arrays.asList( + new Term.Variable(syms.insert("sym")), + new Term.Variable(syms.insert("int")), + new Term.Variable(syms.insert("str"))))), + Arrays.asList( + new Expression( + new ArrayList( + Arrays.asList( + new Op.Value( + new Term.Set( + new HashSet<>( + Arrays.asList( + new Term.Str(abcSymId), new Term.Str(ghiSymId))))), + new Op.Value(new Term.Variable(syms.insert("sym"))), + new Op.Binary(Op.BinaryOp.Contains), + new Op.Unary(Op.UnaryOp.Negate)))))); + + System.out.println("testing r2: " + syms.formatRule(r2)); + res = w.queryRule(r2, (long) 0, new TrustedOrigins(0), syms); + for (Iterator it = res.stream().iterator(); it.hasNext(); ) { + Fact f = it.next(); + System.out.println("\t" + syms.formatFact(f)); + } + expected = + new FactSet( + new Origin(0), + new HashSet<>( + Arrays.asList( + new Fact( + new Predicate( + symbol_set, + Arrays.asList(def, new Term.Integer(2), syms.add("hello"))))))); + assertEquals(expected, res); + + final Rule r3 = + new Rule( + new Predicate( + string_set, + Arrays.asList( + new Term.Variable(syms.insert("sym")), + new Term.Variable(syms.insert("int")), + new Term.Variable(syms.insert("str")))), + Arrays.asList( + new Predicate( + x, + Arrays.asList( + new Term.Variable(syms.insert("sym")), + new Term.Variable(syms.insert("int")), + new Term.Variable(syms.insert("str"))))), + Arrays.asList( + new Expression( + new ArrayList( + Arrays.asList( + new Op.Value( + new Term.Set( + new HashSet<>( + Arrays.asList(syms.add("test"), syms.add("aaa"))))), + new Op.Value(new Term.Variable(syms.insert("str"))), + new Op.Binary(Op.BinaryOp.Contains)))))); + System.out.println("testing r3: " + syms.formatRule(r3)); + res = w.queryRule(r3, (long) 0, new TrustedOrigins(0), syms); + for (Iterator it = res.stream().iterator(); it.hasNext(); ) { + Fact f = it.next(); + System.out.println("\t" + syms.formatFact(f)); + } + expected = + new FactSet( + new Origin(0), + new HashSet<>( + Arrays.asList( + new Fact( + new Predicate( + string_set, + Arrays.asList(abc, new Term.Integer(0), syms.add("test"))))))); + assertEquals(expected, res); + } + + @Test + public void testResource() throws Error { + final World w = new World(); + final SymbolTable syms = new SymbolTable(); + + final Term authority = syms.add("authority"); + final Term ambient = syms.add("ambient"); + final long resource = syms.insert("resource"); + final long operation = syms.insert("operation"); + final long right = syms.insert("right"); + final Term file1 = syms.add("file1"); + final Term file2 = syms.add("file2"); + final Term read = syms.add("read"); + final Term write = syms.add("write"); + + w.addFact(new Origin(0), new Fact(new Predicate(right, Arrays.asList(file1, read)))); + w.addFact(new Origin(0), new Fact(new Predicate(right, Arrays.asList(file2, read)))); + w.addFact(new Origin(0), new Fact(new Predicate(right, Arrays.asList(file1, write)))); + + final long caveat1 = syms.insert("caveat1"); + // r1: caveat2(#file1) <- resource(#ambient, #file1) + final Rule r1 = + new Rule( + new Predicate(caveat1, Arrays.asList(file1)), + Arrays.asList(new Predicate(resource, Arrays.asList(file1))), + new ArrayList<>()); + + System.out.println("testing caveat 1(should return nothing): " + syms.formatRule(r1)); + FactSet res = w.queryRule(r1, (long) 0, new TrustedOrigins(0), syms); + System.out.println(res); + for (Iterator it = res.stream().iterator(); it.hasNext(); ) { + Fact f = it.next(); + System.out.println("\t" + syms.formatFact(f)); + } + assertTrue(res.size() == 0); + + final long caveat2 = syms.insert("caveat2"); + final long var0_id = syms.insert("var0"); + final Term var0 = new Term.Variable(var0_id); + // r2: caveat1(0?) <- resource(#ambient, 0?) && operation(#ambient, #read) && right(#authority, + // 0?, #read) + final Rule r2 = + new Rule( + new Predicate(caveat2, Arrays.asList(var0)), + Arrays.asList( + new Predicate(resource, Arrays.asList(var0)), + new Predicate(operation, Arrays.asList(read)), + new Predicate(right, Arrays.asList(var0, read))), + new ArrayList<>()); + + System.out.println("testing caveat 2: " + syms.formatRule(r2)); + res = w.queryRule(r2, (long) 0, new TrustedOrigins(0), syms); + System.out.println(res); + for (Iterator it = res.stream().iterator(); it.hasNext(); ) { + Fact f = it.next(); + System.out.println("\t" + syms.formatFact(f)); + } + assertTrue(res.size() == 0); + } } diff --git a/src/test/java/org/biscuitsec/biscuit/token/AuthorizerTest.java b/src/test/java/org/biscuitsec/biscuit/token/AuthorizerTest.java index 02053d84..5f57a096 100644 --- a/src/test/java/org/biscuitsec/biscuit/token/AuthorizerTest.java +++ b/src/test/java/org/biscuitsec/biscuit/token/AuthorizerTest.java @@ -1,6 +1,14 @@ package org.biscuitsec.biscuit.token; +import static org.biscuitsec.biscuit.token.builder.Utils.constrainedRule; +import static org.junit.jupiter.api.Assertions.assertEquals; + import biscuit.format.schema.Schema; +import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Set; import org.biscuitsec.biscuit.crypto.KeyPair; import org.biscuitsec.biscuit.error.Error; import org.biscuitsec.biscuit.error.Error.Parser; @@ -8,76 +16,64 @@ import org.biscuitsec.biscuit.token.builder.Term; import org.junit.jupiter.api.Test; -import java.security.SecureRandom; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Set; - -import static org.biscuitsec.biscuit.token.builder.Utils.constrained_rule; -import static org.junit.jupiter.api.Assertions.assertEquals; - public class AuthorizerTest { - @Test - public void testAuthorizerPolicy() throws Parser { - Authorizer authorizer = new Authorizer(); - List policies = authorizer.policies; - authorizer.deny(); - assertEquals(1, policies.size()); - - authorizer.add_policy(new Policy( - Arrays.asList( - constrained_rule( - "deny", - new ArrayList<>(), - new ArrayList<>(), - Arrays.asList(new Expression.Value(new Term.Bool(true))) - ) - ), Policy.Kind.Deny)); - assertEquals(2, policies.size()); - - authorizer.add_policy("deny if true"); - assertEquals(3, policies.size()); - } - - - @Test - public void testPuttingSomeFactsInABiscuitAndGettingThemBackOutAgain() throws Exception { - - KeyPair keypair = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, new SecureRandom()); - - Biscuit token = Biscuit.builder(keypair) - .add_authority_fact("email(\"bob@example.com\")") - .add_authority_fact("id(123)") - .add_authority_fact("enabled(true)") - .add_authority_fact("perms([1,2,3])") - .build(); - - Authorizer authorizer = Biscuit.from_b64url(token.serialize_b64url(), keypair.public_key()) - .verify(keypair.public_key()) - .authorizer(); - - Term emailTerm = queryFirstResult(authorizer, "emailfact($name) <- email($name)"); - assertEquals("bob@example.com", ((Term.Str) emailTerm).getValue()); - - Term idTerm = queryFirstResult(authorizer, "idfact($name) <- id($name)"); - assertEquals(123, ((Term.Integer) idTerm).getValue()); - - Term enabledTerm = queryFirstResult(authorizer, "enabledfact($name) <- enabled($name)"); - assertEquals(true, ((Term.Bool) enabledTerm).getValue()); - - Term permsTerm = queryFirstResult(authorizer, "permsfact($name) <- perms($name)"); - assertEquals( - Set.of(new Term.Integer(1), new Term.Integer(2), new Term.Integer(3)), - ((Term.Set) permsTerm).getValue() - ); - } - - private static Term queryFirstResult(Authorizer authorizer, String query) throws Error { - return authorizer.query(query) - .iterator() - .next() - .terms().get(0); - } + @Test + public void testAuthorizerPolicy() throws Parser { + Authorizer authorizer = new Authorizer(); + List policies = authorizer.getPolicies(); + authorizer.deny(); + assertEquals(1, policies.size()); + + authorizer.addPolicy( + new Policy( + Arrays.asList( + constrainedRule( + "deny", + new ArrayList<>(), + new ArrayList<>(), + Arrays.asList(new Expression.Value(new Term.Bool(true))))), + Policy.Kind.DENY)); + assertEquals(2, policies.size()); + + authorizer.addPolicy("deny if true"); + assertEquals(3, policies.size()); + } + + @Test + public void testPuttingSomeFactsInBiscuitAndGettingThemBackOutAgain() throws Exception { + + KeyPair keypair = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, new SecureRandom()); + + Biscuit token = + Biscuit.builder(keypair) + .addAuthorityFact("email(\"bob@example.com\")") + .addAuthorityFact("id(123)") + .addAuthorityFact("enabled(true)") + .addAuthorityFact("perms([1,2,3])") + .build(); + + Authorizer authorizer = + Biscuit.fromBase64Url(token.serializeBase64Url(), keypair.getPublicKey()) + .verify(keypair.getPublicKey()) + .authorizer(); + + Term emailTerm = queryFirstResult(authorizer, "emailfact($name) <- email($name)"); + assertEquals("bob@example.com", ((Term.Str) emailTerm).getValue()); + + Term idTerm = queryFirstResult(authorizer, "idfact($name) <- id($name)"); + assertEquals(123, ((Term.Integer) idTerm).getValue()); + + Term enabledTerm = queryFirstResult(authorizer, "enabledfact($name) <- enabled($name)"); + assertEquals(true, ((Term.Bool) enabledTerm).getValue()); + + Term permsTerm = queryFirstResult(authorizer, "permsfact($name) <- perms($name)"); + assertEquals( + Set.of(new Term.Integer(1), new Term.Integer(2), new Term.Integer(3)), + ((Term.Set) permsTerm).getValue()); + } + + private static Term queryFirstResult(Authorizer authorizer, String query) throws Error { + return authorizer.query(query).iterator().next().terms().get(0); + } } diff --git a/src/test/java/org/biscuitsec/biscuit/token/BiscuitTest.java b/src/test/java/org/biscuitsec/biscuit/token/BiscuitTest.java index 56f21602..84e73077 100644 --- a/src/test/java/org/biscuitsec/biscuit/token/BiscuitTest.java +++ b/src/test/java/org/biscuitsec/biscuit/token/BiscuitTest.java @@ -1,6 +1,31 @@ package org.biscuitsec.biscuit.token; +import static org.biscuitsec.biscuit.crypto.TokenSignature.hex; +import static org.biscuitsec.biscuit.token.builder.Utils.check; +import static org.biscuitsec.biscuit.token.builder.Utils.date; +import static org.biscuitsec.biscuit.token.builder.Utils.fact; +import static org.biscuitsec.biscuit.token.builder.Utils.pred; +import static org.biscuitsec.biscuit.token.builder.Utils.rule; +import static org.biscuitsec.biscuit.token.builder.Utils.str; +import static org.biscuitsec.biscuit.token.builder.Utils.var; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + import biscuit.format.schema.Schema; +import io.vavr.control.Option; +import io.vavr.control.Try; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.security.SignatureException; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Date; +import java.util.List; + import org.biscuitsec.biscuit.crypto.KeyDelegate; import org.biscuitsec.biscuit.crypto.KeyPair; import org.biscuitsec.biscuit.crypto.PublicKey; @@ -10,703 +35,748 @@ import org.biscuitsec.biscuit.error.FailedCheck; import org.biscuitsec.biscuit.error.LogicError; import org.biscuitsec.biscuit.token.builder.Block; - -import io.vavr.control.Option; -import io.vavr.control.Try; - import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; - -import java.security.InvalidKeyException; -import java.security.NoSuchAlgorithmException; -import java.security.SecureRandom; -import java.security.SignatureException; -import java.time.Duration; -import java.time.Instant; -import java.util.*; - -import static org.biscuitsec.biscuit.crypto.TokenSignature.hex; -import static org.biscuitsec.biscuit.token.builder.Utils.*; public class BiscuitTest { - @Test - public void testBasic() throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - byte[] seed = {0, 0, 0, 0}; - SecureRandom rng = new SecureRandom(seed); + @Test + public void testBasic() + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + byte[] seed = {0, 0, 0, 0}; + SecureRandom rng = new SecureRandom(seed); - System.out.println("preparing the authority block"); + System.out.println("preparing the authority block"); - KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - Block authority_builder = new Block(); + Block authorityBuilder = new Block(); - authority_builder.add_fact(fact("right", Arrays.asList(s("file1"), s("read")))); - authority_builder.add_fact(fact("right", Arrays.asList(s("file2"), s("read")))); - authority_builder.add_fact(fact("right", Arrays.asList(s("file1"), s("write")))); + authorityBuilder.addFact(fact("right", Arrays.asList(str("file1"), str("read")))); + authorityBuilder.addFact(fact("right", Arrays.asList(str("file2"), str("read")))); + authorityBuilder.addFact(fact("right", Arrays.asList(str("file1"), str("write")))); - Biscuit b = Biscuit.make(rng, root, authority_builder.build()); + Biscuit b = Biscuit.make(rng, root, authorityBuilder.build()); - System.out.println(b.print()); + System.out.println(b.print()); - System.out.println("serializing the first token"); + System.out.println("serializing the first token"); - byte[] data = b.serialize(); + byte[] data = b.serialize(); - System.out.print("data len: "); - System.out.println(data.length); - System.out.println(hex(data)); + System.out.print("data len: "); + System.out.println(data.length); + System.out.println(hex(data)); - System.out.println("deserializing the first token"); - Biscuit deser = Biscuit.from_bytes(data, root.public_key()); + System.out.println("deserializing the first token"); + Biscuit deser = Biscuit.fromBytes(data, root.getPublicKey()); - System.out.println(deser.print()); + System.out.println(deser.print()); - // SECOND BLOCK - System.out.println("preparing the second block"); + // SECOND BLOCK + System.out.println("preparing the second block"); - KeyPair keypair2 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + KeyPair keypair2 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - Block builder = deser.create_block(); - builder.add_check(check(rule( + Block builder = deser.createBlock(); + builder.addCheck( + check( + rule( "caveat1", - Arrays.asList(var("resource")), + List.of(var("resource")), Arrays.asList( - pred("resource", Arrays.asList(var("resource"))), - pred("operation", Arrays.asList(s("read"))), - pred("right", Arrays.asList(var("resource"), s("read"))) - ) - ))); + pred("resource", List.of(var("resource"))), + pred("operation", List.of(str("read"))), + pred("right", Arrays.asList(var("resource"), str("read"))))))); - Biscuit b2 = deser.attenuate(rng, keypair2, builder); + Biscuit b2 = deser.attenuate(rng, keypair2, builder); - System.out.println(b2.print()); + System.out.println(b2.print()); - System.out.println("serializing the second token"); + System.out.println("serializing the second token"); - byte[] data2 = b2.serialize(); + byte[] data2 = b2.serialize(); - System.out.print("data len: "); - System.out.println(data2.length); - System.out.println(hex(data2)); + System.out.print("data len: "); + System.out.println(data2.length); + System.out.println(hex(data2)); - System.out.println("deserializing the second token"); - Biscuit deser2 = Biscuit.from_bytes(data2, root.public_key()); + System.out.println("deserializing the second token"); + Biscuit deser2 = Biscuit.fromBytes(data2, root.getPublicKey()); - System.out.println(deser2.print()); + System.out.println(deser2.print()); - // THIRD BLOCK - System.out.println("preparing the third block"); + // THIRD BLOCK + System.out.println("preparing the third block"); - KeyPair keypair3 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + KeyPair keypair3 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - Block builder3 = deser2.create_block(); - builder3.add_check(check(rule( + Block builder3 = deser2.createBlock(); + builder3.addCheck( + check( + rule( "caveat2", - Arrays.asList(s("file1")), - Arrays.asList( - pred("resource", Arrays.asList(s("file1"))) - ) - ))); - - Biscuit b3 = deser2.attenuate(rng, keypair3, builder3); + List.of(str("file1")), + List.of(pred("resource", List.of(str("file1"))))))); + + Biscuit b3 = deser2.attenuate(rng, keypair3, builder3); + + System.out.println(b3.print()); + + System.out.println("serializing the third token"); + + byte[] data3 = b3.serialize(); + + System.out.print("data len: "); + System.out.println(data3.length); + System.out.println(hex(data3)); + + System.out.println("deserializing the third token"); + Biscuit finalToken = Biscuit.fromBytes(data3, root.getPublicKey()); + + System.out.println(finalToken.print()); + + // check + System.out.println("will check the token for resource=file1 and operation=read"); + + Authorizer authorizer = finalToken.authorizer(); + authorizer.addFact("resource(\"file1\")"); + authorizer.addFact("operation(\"read\")"); + authorizer.addPolicy("allow if true"); + authorizer.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); + + System.out.println("will check the token for resource=file2 and operation=write"); + + Authorizer authorizer2 = finalToken.authorizer(); + authorizer2.addFact("resource(\"file2\")"); + authorizer2.addFact("operation(\"write\")"); + authorizer2.addPolicy("allow if true"); + + try { + authorizer2.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); + } catch (Error e) { + System.out.println(e); + assertEquals( + new Error.FailedLogic( + new LogicError.Unauthorized( + new LogicError.MatchedPolicy.Allow(0), + Arrays.asList( + new FailedCheck.FailedBlock( + 1, + 0, + "check if resource($resource), operation(\"read\"), right($resource," + + " \"read\")"), + new FailedCheck.FailedBlock(2, 0, "check if resource(\"file1\")")))), + e); + } + } - System.out.println(b3.print()); + @Test + public void testFolders() throws NoSuchAlgorithmException, Error { + byte[] seed = {0, 0, 0, 0}; + SecureRandom rng = new SecureRandom(seed); - System.out.println("serializing the third token"); + System.out.println("preparing the authority block"); - byte[] data3 = b3.serialize(); + KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - System.out.print("data len: "); - System.out.println(data3.length); - System.out.println(hex(data3)); + org.biscuitsec.biscuit.token.builder.Biscuit builder = Biscuit.builder(rng, root); - System.out.println("deserializing the third token"); - Biscuit final_token = Biscuit.from_bytes(data3, root.public_key()); + builder.addRight("/folder1/file1", "read"); + builder.addRight("/folder1/file1", "write"); + builder.addRight("/folder1/file2", "read"); + builder.addRight("/folder1/file2", "write"); + builder.addRight("/folder2/file3", "read"); - System.out.println(final_token.print()); + System.out.println(builder.build()); + Biscuit b = builder.build(); - // check - System.out.println("will check the token for resource=file1 and operation=read"); + System.out.println(b.print()); - Authorizer authorizer = final_token.authorizer(); - authorizer.add_fact("resource(\"file1\")"); - authorizer.add_fact("operation(\"read\")"); - authorizer.add_policy("allow if true"); - authorizer.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); + Block block2 = b.createBlock(); + block2.resourcePrefix("/folder1/"); + block2.checkRight("read"); - System.out.println("will check the token for resource=file2 and operation=write"); + KeyPair keypair2 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + Biscuit b2 = b.attenuate(rng, keypair2, block2); - Authorizer authorizer2 = final_token.authorizer(); - authorizer2.add_fact("resource(\"file2\")"); - authorizer2.add_fact("operation(\"write\")"); - authorizer2.add_policy("allow if true"); + Authorizer v1 = b2.authorizer(); + v1.addFact("resource(\"/folder1/file1\")"); + v1.addFact("operation(\"read\")"); + v1.allow(); + v1.authorize(); - try { - authorizer2.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); - } catch (Error e) { - System.out.println(e); - assertEquals( - new Error.FailedLogic(new LogicError.Unauthorized(new LogicError.MatchedPolicy.Allow(0), Arrays.asList( - new FailedCheck.FailedBlock(1, 0, "check if resource($resource), operation(\"read\"), right($resource, \"read\")"), - new FailedCheck.FailedBlock(2, 0, "check if resource(\"file1\")") - ))), - e); - } + Authorizer v2 = b2.authorizer(); + v2.addFact("resource(\"/folder2/file3\")"); + v2.addFact("operation(\"read\")"); + v2.allow(); + try { + v2.authorize(); + fail(); + } catch (Error e2) { + // Empty } - @Test - public void testFolders() throws NoSuchAlgorithmException, Error { - byte[] seed = {0, 0, 0, 0}; - SecureRandom rng = new SecureRandom(seed); - - System.out.println("preparing the authority block"); - - KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - - org.biscuitsec.biscuit.token.builder.Biscuit builder = Biscuit.builder(rng, root); - - builder.add_right("/folder1/file1", "read"); - builder.add_right("/folder1/file1", "write"); - builder.add_right("/folder1/file2", "read"); - builder.add_right("/folder1/file2", "write"); - builder.add_right("/folder2/file3", "read"); - - System.out.println(builder.build()); - Biscuit b = builder.build(); - - System.out.println(b.print()); - - Block block2 = b.create_block(); - block2.resource_prefix("/folder1/"); - block2.check_right("read"); - - KeyPair keypair2 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - Biscuit b2 = b.attenuate(rng, keypair2, block2); - - Authorizer v1 = b2.authorizer(); - v1.add_fact("resource(\"/folder1/file1\")"); - v1.add_fact("operation(\"read\")"); - v1.allow(); - v1.authorize(); - - Authorizer v2 = b2.authorizer(); - v2.add_fact("resource(\"/folder2/file3\")"); - v2.add_fact("operation(\"read\")"); - v2.allow(); - try { - v2.authorize(); - fail(); - } catch (Error e2) { - // Empty - } - - Authorizer v3 = b2.authorizer(); - v3.add_fact("resource(\"/folder2/file1\")"); - v3.add_fact("operation(\"write\")"); - v3.allow(); - try { - v3.authorize(); - fail(); - } catch (Error e) { - System.out.println(v3.print_world()); - for (FailedCheck f : e.failed_checks().get()) { - System.out.println(f.toString()); - } - assertEquals( - new Error.FailedLogic(new LogicError.Unauthorized(new LogicError.MatchedPolicy.Allow(0), Arrays.asList( - new FailedCheck.FailedBlock(1, 0, "check if resource($resource), $resource.starts_with(\"/folder1/\")"), - new FailedCheck.FailedBlock(1, 1, "check if resource($resource), operation(\"read\"), right($resource, \"read\")") - ))), - e); - } + Authorizer v3 = b2.authorizer(); + v3.addFact("resource(\"/folder2/file1\")"); + v3.addFact("operation(\"write\")"); + v3.allow(); + try { + v3.authorize(); + fail(); + } catch (Error e) { + System.out.println(v3.formatWorld()); + for (FailedCheck f : e.getFailedChecks().get()) { + System.out.println(f.toString()); + } + assertEquals( + new Error.FailedLogic( + new LogicError.Unauthorized( + new LogicError.MatchedPolicy.Allow(0), + Arrays.asList( + new FailedCheck.FailedBlock( + 1, + 0, + "check if resource($resource), $resource.starts_with(\"/folder1/\")"), + new FailedCheck.FailedBlock( + 1, + 1, + "check if resource($resource), operation(\"read\"), right($resource," + + " \"read\")")))), + e); } + } - @Test - public void testMultipleAttenuation() throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - SecureRandom rng = new SecureRandom(); - KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - - Block authority_builder = new Block(); - Date date = Date.from(Instant.now()); - authority_builder.add_fact(fact("revocation_id", Arrays.asList(date(date)))); + @Test + public void testMultipleAttenuation() + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + SecureRandom rng = new SecureRandom(); + KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - Biscuit biscuit = Biscuit.make(rng, root, authority_builder.build()); + Block authorityBuilder = new Block(); + Date date = Date.from(Instant.now()); + authorityBuilder.addFact(fact("revocation_id", List.of(date(date)))); - Block builder = biscuit.create_block(); - builder.add_fact(fact( - "right", - Arrays.asList(s("topic"), s("tenant"), s("namespace"), s("topic"), s("produce")) - )); + Biscuit biscuit = Biscuit.make(rng, root, authorityBuilder.build()); - String attenuatedB64 = biscuit.attenuate(rng, KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng), builder).serialize_b64url(); + Block builder = biscuit.createBlock(); + builder.addFact( + fact( + "right", + Arrays.asList( + str("topic"), str("tenant"), str("namespace"), str("topic"), str("produce")))); - System.out.println("attenuated: " + attenuatedB64); + String attenuatedB64 = + biscuit + .attenuate(rng, KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng), builder) + .serializeBase64Url(); - Biscuit.from_b64url(attenuatedB64, root.public_key()); - String attenuated2B64 = biscuit.attenuate(rng, KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng), builder).serialize_b64url(); + System.out.println("attenuated: " + attenuatedB64); - System.out.println("attenuated2: " + attenuated2B64); - Biscuit.from_b64url(attenuated2B64, root.public_key()); - } + Biscuit.fromBase64Url(attenuatedB64, root.getPublicKey()); + String attenuated2B64 = + biscuit + .attenuate(rng, KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng), builder) + .serializeBase64Url(); - @Test - public void testReset() throws Error { - byte[] seed = {0, 0, 0, 0}; - SecureRandom rng = new SecureRandom(seed); + System.out.println("attenuated2: " + attenuated2B64); + Biscuit.fromBase64Url(attenuated2B64, root.getPublicKey()); + } - System.out.println("preparing the authority block"); + @Test + public void testReset() throws Error { + byte[] seed = {0, 0, 0, 0}; + SecureRandom rng = new SecureRandom(seed); - KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + System.out.println("preparing the authority block"); - org.biscuitsec.biscuit.token.builder.Biscuit builder = Biscuit.builder(rng, root); + KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - builder.add_right("/folder1/file1", "read"); - builder.add_right("/folder1/file1", "write"); - builder.add_right("/folder1/file2", "read"); - builder.add_right("/folder1/file2", "write"); - builder.add_right("/folder2/file3", "read"); + org.biscuitsec.biscuit.token.builder.Biscuit builder = Biscuit.builder(rng, root); - System.out.println(builder.build()); - Biscuit b = builder.build(); + builder.addRight("/folder1/file1", "read"); + builder.addRight("/folder1/file1", "write"); + builder.addRight("/folder1/file2", "read"); + builder.addRight("/folder1/file2", "write"); + builder.addRight("/folder2/file3", "read"); - System.out.println(b.print()); + System.out.println(builder.build()); + Biscuit b = builder.build(); - Block block2 = b.create_block(); - block2.resource_prefix("/folder1/"); - block2.check_right("read"); + System.out.println(b.print()); - KeyPair keypair2 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - Biscuit b2 = b.attenuate(rng, keypair2, block2); + Block block2 = b.createBlock(); + block2.resourcePrefix("/folder1/"); + block2.checkRight("read"); - Authorizer v1 = b2.authorizer(); - v1.allow(); + KeyPair keypair2 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + Biscuit b2 = b.attenuate(rng, keypair2, block2); - Authorizer v2 = v1.clone(); + Authorizer v1 = b2.authorizer(); + v1.allow(); - v2.add_fact("resource(\"/folder1/file1\")"); - v2.add_fact("operation(\"read\")"); + Authorizer v2 = v1.clone(); + v2.addFact("resource(\"/folder1/file1\")"); + v2.addFact("operation(\"read\")"); - v2.authorize(); + v2.authorize(); - Authorizer v3 = v1.clone(); + Authorizer v3 = v1.clone(); - v3.add_fact("resource(\"/folder2/file3\")"); - v3.add_fact("operation(\"read\")"); + v3.addFact("resource(\"/folder2/file3\")"); + v3.addFact("operation(\"read\")"); - Try res = Try.of(() -> v3.authorize()); - System.out.println(v3.print_world()); + Try res = Try.of(() -> v3.authorize()); + System.out.println(v3.formatWorld()); - assertTrue(res.isFailure()); + assertTrue(res.isFailure()); - Authorizer v4 = v1.clone(); + Authorizer v4 = v1.clone(); - v4.add_fact("resource(\"/folder2/file1\")"); - v4.add_fact("operation(\"write\")"); + v4.addFact("resource(\"/folder2/file1\")"); + v4.addFact("operation(\"write\")"); - Error e = (Error) Try.of(() -> v4.authorize()).getCause(); + Error e = (Error) Try.of(() -> v4.authorize()).getCause(); - System.out.println(v4.print_world()); - for (FailedCheck f : e.failed_checks().get()) { - System.out.println(f.toString()); - } - assertEquals( - new Error.FailedLogic(new LogicError.Unauthorized(new LogicError.MatchedPolicy.Allow(0), Arrays.asList( - new FailedCheck.FailedBlock(1, 0, "check if resource($resource), $resource.starts_with(\"/folder1/\")"), - new FailedCheck.FailedBlock(1, 1, "check if resource($resource), operation(\"read\"), right($resource, \"read\")") - ))), - e); + System.out.println(v4.formatWorld()); + for (FailedCheck f : e.getFailedChecks().get()) { + System.out.println(f.toString()); } + assertEquals( + new Error.FailedLogic( + new LogicError.Unauthorized( + new LogicError.MatchedPolicy.Allow(0), + Arrays.asList( + new FailedCheck.FailedBlock( + 1, 0, "check if resource($resource), $resource.starts_with(\"/folder1/\")"), + new FailedCheck.FailedBlock( + 1, + 1, + "check if resource($resource), operation(\"read\"), right($resource," + + " \"read\")")))), + e); + } - @Test - public void testEmptyAuthorizer() throws Error { - byte[] seed = {0, 0, 0, 0}; - SecureRandom rng = new SecureRandom(seed); + @Test + public void testEmptyAuthorizer() throws Error { + byte[] seed = {0, 0, 0, 0}; + SecureRandom rng = new SecureRandom(seed); - System.out.println("preparing the authority block"); + System.out.println("preparing the authority block"); - KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - org.biscuitsec.biscuit.token.builder.Biscuit builder = Biscuit.builder(rng, root); + org.biscuitsec.biscuit.token.builder.Biscuit builder = Biscuit.builder(rng, root); - builder.add_right("/folder1/file1", "read"); - builder.add_right("/folder1/file1", "write"); - builder.add_right("/folder1/file2", "read"); - builder.add_right("/folder1/file2", "write"); - builder.add_right("/folder2/file3", "read"); + builder.addRight("/folder1/file1", "read"); + builder.addRight("/folder1/file1", "write"); + builder.addRight("/folder1/file2", "read"); + builder.addRight("/folder1/file2", "write"); + builder.addRight("/folder2/file3", "read"); - System.out.println(builder.build()); - Biscuit b = builder.build(); + System.out.println(builder.build()); + Biscuit b = builder.build(); - System.out.println(b.print()); + System.out.println(b.print()); - Block block2 = b.create_block(); - block2.resource_prefix("/folder1/"); - block2.check_right("read"); + Block block2 = b.createBlock(); + block2.resourcePrefix("/folder1/"); + block2.checkRight("read"); - KeyPair keypair2 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - Biscuit b2 = b.attenuate(rng, keypair2, block2); + KeyPair keypair2 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + Biscuit b2 = b.attenuate(rng, keypair2, block2); - Authorizer v1 = new Authorizer(); - v1.allow(); + Authorizer v1 = new Authorizer(); + v1.allow(); - v1.authorize(); + v1.authorize(); - v1.add_token(b2); + v1.addToken(b2); - v1.add_fact("resource(\"/folder2/file1\")"); - v1.add_fact("operation(\"write\")"); + v1.addFact("resource(\"/folder2/file1\")"); + v1.addFact("operation(\"write\")"); - assertTrue(Try.of(()-> v1.authorize()).isFailure()); - } + assertTrue(Try.of(() -> v1.authorize()).isFailure()); + } - @Test - public void testBasicWithNamespaces() throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - byte[] seed = {0, 0, 0, 0}; - SecureRandom rng = new SecureRandom(seed); + @Test + public void testBasicWithNamespaces() + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + byte[] seed = {0, 0, 0, 0}; + SecureRandom rng = new SecureRandom(seed); - System.out.println("preparing the authority block"); + System.out.println("preparing the authority block"); - KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - Block authority_builder = new Block(); + Block authorityBuilder = new Block(); - authority_builder.add_fact(fact("namespace:right", Arrays.asList(s("file1"), s("read")))); - authority_builder.add_fact(fact("namespace:right", Arrays.asList(s("file1"), s("write")))); - authority_builder.add_fact(fact("namespace:right", Arrays.asList(s("file2"), s("read")))); - Biscuit b = Biscuit.make(rng, root, authority_builder.build()); + authorityBuilder.addFact(fact("namespace:right", Arrays.asList(str("file1"), str("read")))); + authorityBuilder.addFact(fact("namespace:right", Arrays.asList(str("file1"), str("write")))); + authorityBuilder.addFact(fact("namespace:right", Arrays.asList(str("file2"), str("read")))); + Biscuit b = Biscuit.make(rng, root, authorityBuilder.build()); - System.out.println(b.print()); + System.out.println(b.print()); - System.out.println("serializing the first token"); + System.out.println("serializing the first token"); - byte[] data = b.serialize(); + byte[] data = b.serialize(); - System.out.print("data len: "); - System.out.println(data.length); - System.out.println(hex(data)); + System.out.print("data len: "); + System.out.println(data.length); + System.out.println(hex(data)); - System.out.println("deserializing the first token"); - Biscuit deser = Biscuit.from_bytes(data, root.public_key()); + System.out.println("deserializing the first token"); + Biscuit deser = Biscuit.fromBytes(data, root.getPublicKey()); - System.out.println(deser.print()); + System.out.println(deser.print()); - // SECOND BLOCK - System.out.println("preparing the second block"); + // SECOND BLOCK + System.out.println("preparing the second block"); - KeyPair keypair2 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + KeyPair keypair2 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - Block builder = deser.create_block(); - builder.add_check(check(rule( + Block builder = deser.createBlock(); + builder.addCheck( + check( + rule( "caveat1", - Arrays.asList(var("resource")), + List.of(var("resource")), Arrays.asList( - pred("resource", Arrays.asList(var("resource"))), - pred("operation", Arrays.asList(s("read"))), - pred("namespace:right", Arrays.asList(var("resource"), s("read"))) - ) - ))); + pred("resource", List.of(var("resource"))), + pred("operation", List.of(str("read"))), + pred("namespace:right", Arrays.asList(var("resource"), str("read"))))))); - Biscuit b2 = deser.attenuate(rng, keypair2, builder); + Biscuit b2 = deser.attenuate(rng, keypair2, builder); - System.out.println(b2.print()); + System.out.println(b2.print()); - System.out.println("serializing the second token"); + System.out.println("serializing the second token"); - byte[] data2 = b2.serialize(); + byte[] data2 = b2.serialize(); - System.out.print("data len: "); - System.out.println(data2.length); - System.out.println(hex(data2)); + System.out.print("data len: "); + System.out.println(data2.length); + System.out.println(hex(data2)); - System.out.println("deserializing the second token"); - Biscuit deser2 = Biscuit.from_bytes(data2, root.public_key()); + System.out.println("deserializing the second token"); + Biscuit deser2 = Biscuit.fromBytes(data2, root.getPublicKey()); - System.out.println(deser2.print()); + System.out.println(deser2.print()); - // THIRD BLOCK - System.out.println("preparing the third block"); + // THIRD BLOCK + System.out.println("preparing the third block"); - KeyPair keypair3 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + KeyPair keypair3 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - Block builder3 = deser2.create_block(); - builder3.add_check(check(rule( + Block builder3 = deser2.createBlock(); + builder3.addCheck( + check( + rule( "caveat2", - Arrays.asList(s("file1")), - Arrays.asList( - pred("resource", Arrays.asList(s("file1"))) - ) - ))); - - Biscuit b3 = deser2.attenuate(rng, keypair3, builder3); - - System.out.println(b3.print()); - - System.out.println("serializing the third token"); - - byte[] data3 = b3.serialize(); - - System.out.print("data len: "); - System.out.println(data3.length); - System.out.println(hex(data3)); - - System.out.println("deserializing the third token"); - Biscuit final_token = Biscuit.from_bytes(data3, root.public_key()); - - System.out.println(final_token.print()); - - // check - System.out.println("will check the token for resource=file1 and operation=read"); - - Authorizer authorizer = final_token.authorizer(); - authorizer.add_fact("resource(\"file1\")"); - authorizer.add_fact("operation(\"read\")"); - authorizer.add_policy("allow if true"); - authorizer.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); - - System.out.println("will check the token for resource=file2 and operation=write"); - - Authorizer authorizer2 = final_token.authorizer(); - authorizer2.add_fact("resource(\"file2\")"); - authorizer2.add_fact("operation(\"write\")"); - authorizer2.add_policy("allow if true"); - - try { - authorizer2.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); - } catch (Error e) { - System.out.println(e); - assertEquals( - new Error.FailedLogic(new LogicError.Unauthorized(new LogicError.MatchedPolicy.Allow(0), Arrays.asList( - new FailedCheck.FailedBlock(1, 0, "check if resource($resource), operation(\"read\"), namespace:right($resource, \"read\")"), - new FailedCheck.FailedBlock(2, 0, "check if resource(\"file1\")") - ))), - e); - } + List.of(str("file1")), + List.of(pred("resource", List.of(str("file1"))))))); + + Biscuit b3 = deser2.attenuate(rng, keypair3, builder3); + + System.out.println(b3.print()); + + System.out.println("serializing the third token"); + + byte[] data3 = b3.serialize(); + + System.out.print("data len: "); + System.out.println(data3.length); + System.out.println(hex(data3)); + + System.out.println("deserializing the third token"); + Biscuit finalToken = Biscuit.fromBytes(data3, root.getPublicKey()); + + System.out.println(finalToken.print()); + + // check + System.out.println("will check the token for resource=file1 and operation=read"); + + Authorizer authorizer = finalToken.authorizer(); + authorizer.addFact("resource(\"file1\")"); + authorizer.addFact("operation(\"read\")"); + authorizer.addPolicy("allow if true"); + authorizer.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); + + System.out.println("will check the token for resource=file2 and operation=write"); + + Authorizer authorizer2 = finalToken.authorizer(); + authorizer2.addFact("resource(\"file2\")"); + authorizer2.addFact("operation(\"write\")"); + authorizer2.addPolicy("allow if true"); + + try { + authorizer2.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); + } catch (Error e) { + System.out.println(e); + assertEquals( + new Error.FailedLogic( + new LogicError.Unauthorized( + new LogicError.MatchedPolicy.Allow(0), + Arrays.asList( + new FailedCheck.FailedBlock( + 1, + 0, + "check if resource($resource), operation(\"read\")," + + " namespace:right($resource, \"read\")"), + new FailedCheck.FailedBlock(2, 0, "check if resource(\"file1\")")))), + e); } + } - @Test - public void testBasicWithNamespacesWithAddAuthorityFact() throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - byte[] seed = {0, 0, 0, 0}; - SecureRandom rng = new SecureRandom(seed); + @Test + public void testBasicWithNamespacesWithAddAuthorityFact() + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + byte[] seed = {0, 0, 0, 0}; + SecureRandom rng = new SecureRandom(seed); - System.out.println("preparing the authority block"); + System.out.println("preparing the authority block"); - KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - SymbolTable symbols = Biscuit.default_symbol_table(); - org.biscuitsec.biscuit.token.builder.Biscuit o = new org.biscuitsec.biscuit.token.builder.Biscuit(rng, root); - o.add_authority_fact("namespace:right(\"file1\",\"read\")"); - o.add_authority_fact("namespace:right(\"file1\",\"write\")"); - o.add_authority_fact("namespace:right(\"file2\",\"read\")"); - Biscuit b = o.build(); + org.biscuitsec.biscuit.token.builder.Biscuit o = + new org.biscuitsec.biscuit.token.builder.Biscuit(rng, root); + o.addAuthorityFact("namespace:right(\"file1\",\"read\")"); + o.addAuthorityFact("namespace:right(\"file1\",\"write\")"); + o.addAuthorityFact("namespace:right(\"file2\",\"read\")"); + Biscuit b = o.build(); - System.out.println(b.print()); + System.out.println(b.print()); - System.out.println("serializing the first token"); + System.out.println("serializing the first token"); - byte[] data = b.serialize(); + byte[] data = b.serialize(); - System.out.print("data len: "); - System.out.println(data.length); - System.out.println(hex(data)); + System.out.print("data len: "); + System.out.println(data.length); + System.out.println(hex(data)); - System.out.println("deserializing the first token"); - Biscuit deser = Biscuit.from_bytes(data, root.public_key()); + System.out.println("deserializing the first token"); + Biscuit deser = Biscuit.fromBytes(data, root.getPublicKey()); - System.out.println(deser.print()); + System.out.println(deser.print()); - // SECOND BLOCK - System.out.println("preparing the second block"); + // SECOND BLOCK + System.out.println("preparing the second block"); - KeyPair keypair2 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + KeyPair keypair2 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - Block builder = deser.create_block(); - builder.add_check(check(rule( + Block builder = deser.createBlock(); + builder.addCheck( + check( + rule( "caveat1", - Arrays.asList(var("resource")), + List.of(var("resource")), Arrays.asList( - pred("resource", Arrays.asList(var("resource"))), - pred("operation", Arrays.asList(s("read"))), - pred("namespace:right", Arrays.asList(var("resource"), s("read"))) - ) - ))); + pred("resource", List.of(var("resource"))), + pred("operation", List.of(str("read"))), + pred("namespace:right", Arrays.asList(var("resource"), str("read"))))))); - Biscuit b2 = deser.attenuate(rng, keypair2, builder); + Biscuit b2 = deser.attenuate(rng, keypair2, builder); - System.out.println(b2.print()); + System.out.println(b2.print()); - System.out.println("serializing the second token"); + System.out.println("serializing the second token"); - byte[] data2 = b2.serialize(); + byte[] data2 = b2.serialize(); - System.out.print("data len: "); - System.out.println(data2.length); - System.out.println(hex(data2)); + System.out.print("data len: "); + System.out.println(data2.length); + System.out.println(hex(data2)); - System.out.println("deserializing the second token"); - Biscuit deser2 = Biscuit.from_bytes(data2, root.public_key()); + System.out.println("deserializing the second token"); + Biscuit deser2 = Biscuit.fromBytes(data2, root.getPublicKey()); - System.out.println(deser2.print()); + System.out.println(deser2.print()); - // THIRD BLOCK - System.out.println("preparing the third block"); + // THIRD BLOCK + System.out.println("preparing the third block"); - KeyPair keypair3 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + KeyPair keypair3 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - Block builder3 = deser2.create_block(); - builder3.add_check(check(rule( + Block builder3 = deser2.createBlock(); + builder3.addCheck( + check( + rule( "caveat2", - Arrays.asList(s("file1")), - Arrays.asList( - pred("resource", Arrays.asList(s("file1"))) - ) - ))); - - Biscuit b3 = deser2.attenuate(rng, keypair3, builder3); - - System.out.println(b3.print()); - - System.out.println("serializing the third token"); - - byte[] data3 = b3.serialize(); - - System.out.print("data len: "); - System.out.println(data3.length); - System.out.println(hex(data3)); - - System.out.println("deserializing the third token"); - Biscuit final_token = Biscuit.from_bytes(data3, root.public_key()); - - System.out.println(final_token.print()); - - // check - System.out.println("will check the token for resource=file1 and operation=read"); - - Authorizer authorizer = final_token.authorizer(); - authorizer.add_fact("resource(\"file1\")"); - authorizer.add_fact("operation(\"read\")"); - authorizer.add_policy("allow if true"); - authorizer.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); - - System.out.println("will check the token for resource=file2 and operation=write"); - - Authorizer authorizer2 = final_token.authorizer(); - authorizer2.add_fact("resource(\"file2\")"); - authorizer2.add_fact("operation(\"write\")"); - authorizer2.add_policy("allow if true"); - try { - authorizer2.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); - } catch (Error e) { - System.out.println(e); - assertEquals( - new Error.FailedLogic(new LogicError.Unauthorized(new LogicError.MatchedPolicy.Allow(0), Arrays.asList( - new FailedCheck.FailedBlock(1, 0, "check if resource($resource), operation(\"read\"), namespace:right($resource, \"read\")"), - new FailedCheck.FailedBlock(2, 0, "check if resource(\"file1\")") - ))), - e); - } + List.of(str("file1")), + List.of(pred("resource", List.of(str("file1"))))))); + + Biscuit b3 = deser2.attenuate(rng, keypair3, builder3); + + System.out.println(b3.print()); + + System.out.println("serializing the third token"); + + byte[] data3 = b3.serialize(); + + System.out.print("data len: "); + System.out.println(data3.length); + System.out.println(hex(data3)); + + System.out.println("deserializing the third token"); + Biscuit finalToken = Biscuit.fromBytes(data3, root.getPublicKey()); + + System.out.println(finalToken.print()); + + // check + System.out.println("will check the token for resource=file1 and operation=read"); + + Authorizer authorizer = finalToken.authorizer(); + authorizer.addFact("resource(\"file1\")"); + authorizer.addFact("operation(\"read\")"); + authorizer.addPolicy("allow if true"); + authorizer.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); + + System.out.println("will check the token for resource=file2 and operation=write"); + + Authorizer authorizer2 = finalToken.authorizer(); + authorizer2.addFact("resource(\"file2\")"); + authorizer2.addFact("operation(\"write\")"); + authorizer2.addPolicy("allow if true"); + try { + authorizer2.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); + } catch (Error e) { + System.out.println(e); + assertEquals( + new Error.FailedLogic( + new LogicError.Unauthorized( + new LogicError.MatchedPolicy.Allow(0), + Arrays.asList( + new FailedCheck.FailedBlock( + 1, + 0, + "check if resource($resource), operation(\"read\")," + + " namespace:right($resource, \"read\")"), + new FailedCheck.FailedBlock(2, 0, "check if resource(\"file1\")")))), + e); } + } - @Test - public void testRootKeyId() throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - byte[] seed = {0, 0, 0, 0}; - SecureRandom rng = new SecureRandom(seed); + @Test + public void testRootKeyId() + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + byte[] seed = {0, 0, 0, 0}; + SecureRandom rng = new SecureRandom(seed); - System.out.println("preparing the authority block"); + System.out.println("preparing the authority block"); - KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - Block authority_builder = new Block(); + Block authorityBuilder = new Block(); - authority_builder.add_fact(fact("right", Arrays.asList(s("file1"), s("read")))); - authority_builder.add_fact(fact("right", Arrays.asList(s("file2"), s("read")))); - authority_builder.add_fact(fact("right", Arrays.asList(s("file1"), s("write")))); + authorityBuilder.addFact(fact("right", Arrays.asList(str("file1"), str("read")))); + authorityBuilder.addFact(fact("right", Arrays.asList(str("file2"), str("read")))); + authorityBuilder.addFact(fact("right", Arrays.asList(str("file1"), str("write")))); - Biscuit b = Biscuit.make(rng, root, 1, authority_builder.build()); + Biscuit b = Biscuit.make(rng, root, 1, authorityBuilder.build()); - System.out.println(b.print()); + System.out.println(b.print()); - System.out.println("serializing the first token"); + System.out.println("serializing the first token"); - byte[] data = b.serialize(); + byte[] data = b.serialize(); - System.out.print("data len: "); - System.out.println(data.length); - System.out.println(hex(data)); + System.out.print("data len: "); + System.out.println(data.length); + System.out.println(hex(data)); - System.out.println("deserializing the first token"); + System.out.println("deserializing the first token"); - assertThrows(InvalidKeyException.class, () -> { - Biscuit deser = Biscuit.from_bytes(data, new KeyDelegate() { - @Override - public Option root_key(Option key_id) { - return Option.none(); - } - }); + assertThrows( + InvalidKeyException.class, + () -> { + Biscuit deser = + Biscuit.fromBytes( + data, + new KeyDelegate() { + @Override + public Option getRootKey(Option keyId) { + return Option.none(); + } + }); }); - - assertThrows(Error.FormatError.Signature.InvalidSignature.class, () -> { - Biscuit deser = Biscuit.from_bytes(data, new KeyDelegate() { - @Override - public Option root_key(Option key_id) { - - KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - return Option.some(root.public_key()); - } - }); + assertThrows( + Error.FormatError.Signature.InvalidSignature.class, + () -> { + Biscuit deser = + Biscuit.fromBytes( + data, + new KeyDelegate() { + @Override + public Option getRootKey(Option keyId) { + + KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + return Option.some(root.getPublicKey()); + } + }); }); - Biscuit deser = Biscuit.from_bytes(data, new KeyDelegate() { - @Override - public Option root_key(Option key_id) { - if (key_id.get() == 1) { - return Option.some(root.public_key()); + Biscuit deser = + Biscuit.fromBytes( + data, + new KeyDelegate() { + @Override + public Option getRootKey(Option keyId) { + if (keyId.get() == 1) { + return Option.some(root.getPublicKey()); } else { - return Option.none(); + return Option.none(); } - } - }); - + } + }); + } + + @Test + public void testCheckAll() + throws Error, NoSuchAlgorithmException, SignatureException, InvalidKeyException { + byte[] seed = {0, 0, 0, 0}; + SecureRandom rng = new SecureRandom(seed); + + System.out.println("preparing the authority block"); + + KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + + Biscuit biscuit = + Biscuit.builder(root) + .addAuthorityCheck( + "check all operation($op), allowed_operations($allowed), $allowed.contains($op)") + .build(); + Authorizer authorizer = biscuit.verify(root.getPublicKey()).authorizer(); + authorizer.addFact("operation(\"read\")"); + authorizer.addFact("operation(\"write\")"); + authorizer.addFact("allowed_operations([\"write\"])"); + authorizer.addPolicy("allow if true"); + + try { + authorizer.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); + } catch (Error.FailedLogic e) { + System.out.println(e); + assertEquals( + new Error.FailedLogic( + new LogicError.Unauthorized( + new LogicError.MatchedPolicy.Allow(0), + List.of( + new FailedCheck.FailedBlock( + 0, + 0, + "check all operation($op), allowed_operations($allowed)," + + " $allowed.contains($op)")))), + e); } - @Test - public void testCheckAll() throws Error, NoSuchAlgorithmException, SignatureException, InvalidKeyException { - byte[] seed = {0, 0, 0, 0}; - SecureRandom rng = new SecureRandom(seed); - - System.out.println("preparing the authority block"); - - KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - - Biscuit biscuit = Biscuit.builder(root) - .add_authority_check("check all operation($op), allowed_operations($allowed), $allowed.contains($op)") - .build(); - Authorizer authorizer = biscuit.verify(root.public_key()).authorizer(); - authorizer.add_fact("operation(\"read\")"); - authorizer.add_fact("operation(\"write\")"); - authorizer.add_fact("allowed_operations([\"write\"])"); - authorizer.add_policy("allow if true"); - - try { - authorizer.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); - } catch(Error.FailedLogic e) { - System.out.println(e); - assertEquals(new Error.FailedLogic(new LogicError.Unauthorized( - new LogicError.MatchedPolicy.Allow(0), - Arrays.asList( - new FailedCheck.FailedBlock(0, 0, "check all operation($op), allowed_operations($allowed), $allowed.contains($op)") - ) - )), e); - } - - Authorizer authorizer2 = biscuit.verify(root.public_key()).authorizer(); - authorizer2.add_fact("operation(\"read\")"); - authorizer2.add_fact("operation(\"write\")"); - authorizer2.add_fact("allowed_operations([\"read\", \"write\"])"); - authorizer2.add_policy("allow if true"); - - authorizer2.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); - } + Authorizer authorizer2 = biscuit.verify(root.getPublicKey()).authorizer(); + authorizer2.addFact("operation(\"read\")"); + authorizer2.addFact("operation(\"write\")"); + authorizer2.addFact("allowed_operations([\"read\", \"write\"])"); + authorizer2.addPolicy("allow if true"); + + authorizer2.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); + } } diff --git a/src/test/java/org/biscuitsec/biscuit/token/ExampleTest.java b/src/test/java/org/biscuitsec/biscuit/token/ExampleTest.java index 17e83381..70faef91 100644 --- a/src/test/java/org/biscuitsec/biscuit/token/ExampleTest.java +++ b/src/test/java/org/biscuitsec/biscuit/token/ExampleTest.java @@ -1,44 +1,43 @@ package org.biscuitsec.biscuit.token; import biscuit.format.schema.Schema; -import org.biscuitsec.biscuit.crypto.KeyPair; -import org.biscuitsec.biscuit.error.Error; -import org.biscuitsec.biscuit.token.builder.Block; - import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.security.SignatureException; +import org.biscuitsec.biscuit.crypto.KeyPair; +import org.biscuitsec.biscuit.error.Error; +import org.biscuitsec.biscuit.token.builder.Block; /* example code for the documentation at https://www.biscuitsec.org * if these functions change, please send a PR to update them at https://github.com/biscuit-auth/website */ public class ExampleTest { - public KeyPair root() { - return KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519); - } + public KeyPair root() { + return KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519); + } - public Biscuit createToken(KeyPair root) throws Error { - return Biscuit.builder(root) - .add_authority_fact("user(\"1234\")") - .add_authority_check("check if operation(\"read\")") - .build(); - } + public Biscuit createToken(KeyPair root) throws Error { + return Biscuit.builder(root) + .addAuthorityFact("user(\"1234\")") + .addAuthorityCheck("check if operation(\"read\")") + .build(); + } - public Long authorize(KeyPair root, byte[] serializedToken) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - return Biscuit.from_bytes(serializedToken, root.public_key()).authorizer() - .add_fact("resource(\"/folder1/file1\")") - .add_fact("operation(\"read\")") - .allow() - .authorize(); - } + public Long authorize(KeyPair root, byte[] serializedToken) + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + return Biscuit.fromBytes(serializedToken, root.getPublicKey()) + .authorizer() + .addFact("resource(\"/folder1/file1\")") + .addFact("operation(\"read\")") + .allow() + .authorize(); + } - public Biscuit attenuate(KeyPair root, byte[] serializedToken) throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { - Biscuit token = Biscuit.from_bytes(serializedToken, root.public_key()); - Block block = token.create_block().add_check("check if operation(\"read\")"); - return token.attenuate(block, root.public_key().algorithm); - } + public Biscuit attenuate(KeyPair root, byte[] serializedToken) + throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, Error { + Biscuit token = Biscuit.fromBytes(serializedToken, root.getPublicKey()); + Block block = token.createBlock().addCheck("check if operation(\"read\")"); + return token.attenuate(block, root.getPublicKey().getAlgorithm()); + } - /*public Set query(Authorizer authorizer) throws Error.Timeout, Error.TooManyFacts, Error.TooManyIterations, Error.Parser { - return authorizer.query("data($name, $id) <- user($name, $id)"); - }*/ } diff --git a/src/test/java/org/biscuitsec/biscuit/token/KmsSignerExampleTest.java b/src/test/java/org/biscuitsec/biscuit/token/KmsSignerExampleTest.java index a369e344..07fe7e24 100644 --- a/src/test/java/org/biscuitsec/biscuit/token/KmsSignerExampleTest.java +++ b/src/test/java/org/biscuitsec/biscuit/token/KmsSignerExampleTest.java @@ -1,6 +1,10 @@ package org.biscuitsec.biscuit.token; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + import biscuit.format.schema.Schema.PublicKey.Algorithm; +import java.io.ByteArrayInputStream; +import java.io.IOException; import org.biscuitsec.biscuit.crypto.PublicKey; import org.biscuitsec.biscuit.crypto.Signer; import org.biscuitsec.biscuit.error.Error; @@ -22,90 +26,95 @@ import software.amazon.awssdk.services.kms.model.KeyUsageType; import software.amazon.awssdk.services.kms.model.SigningAlgorithmSpec; -import java.io.ByteArrayInputStream; -import java.io.IOException; - -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; - @Testcontainers public class KmsSignerExampleTest { - private static final DockerImageName LOCALSTACK_IMAGE = DockerImageName.parse("localstack/localstack:4.0.3"); - - @Container - public static LocalStackContainer LOCALSTACK = new LocalStackContainer(LOCALSTACK_IMAGE) - .withServices(LocalStackContainer.Service.KMS); - - private KmsClient kmsClient; - private String kmsKeyId; - - @BeforeEach - public void setup() { - var credentials = AwsBasicCredentials.create(LOCALSTACK.getAccessKey(), LOCALSTACK.getSecretKey()); - kmsClient = KmsClient.builder() - .endpointOverride(LOCALSTACK.getEndpointOverride(LocalStackContainer.Service.KMS)) - .credentialsProvider(StaticCredentialsProvider.create(credentials)) - .region(Region.of(LOCALSTACK.getRegion())) - .build(); - - // ECC_NIST_P256 == SECP256R1 - kmsKeyId = kmsClient.createKey(b -> b - .keySpec(KeySpec.ECC_NIST_P256) - .keyUsage(KeyUsageType.SIGN_VERIFY) - .build() - ).keyMetadata().keyId(); - } - - @Test - public void testCreateBiscuitWithRemoteSigner() throws Error { - var getPublicKeyResponse = kmsClient.getPublicKey(b -> b.keyId(kmsKeyId).build()); - var x509EncodedPublicKey = getPublicKeyResponse.publicKey().asByteArray(); - var sec1CompressedEncodedPublicKey = convertDEREncodedX509PublicKeyToSEC1CompressedEncodedPublicKey(x509EncodedPublicKey); - var publicKey = new PublicKey(Algorithm.SECP256R1, sec1CompressedEncodedPublicKey); - var signer = new Signer() { - @Override - public byte[] sign(byte[] bytes) { - var signResponse = kmsClient.sign(b -> b - .keyId(kmsKeyId) - .signingAlgorithm(SigningAlgorithmSpec.ECDSA_SHA_256) - .message(SdkBytes.fromByteArray(bytes)) - ); - return signResponse.signature().asByteArray(); - } - - @Override - public PublicKey public_key() { - return publicKey; - } + private static final DockerImageName LOCALSTACK_IMAGE = + DockerImageName.parse("localstack/localstack:4.0.3"); + + @Container + public static LocalStackContainer LOCALSTACK = + new LocalStackContainer(LOCALSTACK_IMAGE).withServices(LocalStackContainer.Service.KMS); + + private KmsClient kmsClient; + private String kmsKeyId; + + @BeforeEach + public void setup() { + var credentials = + AwsBasicCredentials.create(LOCALSTACK.getAccessKey(), LOCALSTACK.getSecretKey()); + kmsClient = + KmsClient.builder() + .endpointOverride(LOCALSTACK.getEndpointOverride(LocalStackContainer.Service.KMS)) + .credentialsProvider(StaticCredentialsProvider.create(credentials)) + .region(Region.of(LOCALSTACK.getRegion())) + .build(); + + // ECC_NIST_P256 == SECP256R1 + kmsKeyId = + kmsClient + .createKey( + b -> b.keySpec(KeySpec.ECC_NIST_P256).keyUsage(KeyUsageType.SIGN_VERIFY).build()) + .keyMetadata() + .keyId(); + } + + @Test + public void testCreateBiscuitWithRemoteSigner() throws Error { + var getPublicKeyResponse = kmsClient.getPublicKey(b -> b.keyId(kmsKeyId).build()); + var x509EncodedPublicKey = getPublicKeyResponse.publicKey().asByteArray(); + var sec1CompressedEncodedPublicKey = + convertDerEncodedX509PublicKeyToSec1CompressedEncodedPublicKey(x509EncodedPublicKey); + var publicKey = new PublicKey(Algorithm.SECP256R1, sec1CompressedEncodedPublicKey); + var signer = + new Signer() { + @Override + public byte[] sign(byte[] bytes) { + var signResponse = + kmsClient.sign( + b -> + b.keyId(kmsKeyId) + .signingAlgorithm(SigningAlgorithmSpec.ECDSA_SHA_256) + .message(SdkBytes.fromByteArray(bytes))); + return signResponse.signature().asByteArray(); + } + + @Override + public PublicKey getPublicKey() { + return publicKey; + } }; - var biscuit = Biscuit.builder(signer) - .add_authority_fact("user(\"1234\")") - .add_authority_check("check if operation(\"read\")") - .build(); - var serializedBiscuit = biscuit.serialize(); - var deserializedUnverifiedBiscuit = Biscuit.from_bytes(serializedBiscuit); - var verifiedBiscuit = assertDoesNotThrow(() -> deserializedUnverifiedBiscuit.verify(publicKey)); - - System.out.println(verifiedBiscuit.print()); - } - - private static byte[] convertDEREncodedX509PublicKeyToSEC1CompressedEncodedPublicKey(byte[] publicKeyBytes) { - try (ASN1InputStream asn1InputStream = new ASN1InputStream(new ByteArrayInputStream(publicKeyBytes))) { - - // Parse the ASN.1 encoded public key bytes - var asn1Primitive = asn1InputStream.readObject(); - var subjectPublicKeyInfo = SubjectPublicKeyInfo.getInstance(asn1Primitive); - - // Extract the public key data - var publicKeyDataBitString = subjectPublicKeyInfo.getPublicKeyData(); - byte[] publicKeyData = publicKeyDataBitString.getBytes(); - - // Parse the public key data to get the elliptic curve point - var ecParameters = ECNamedCurveTable.getByName("secp256r1"); - var ecPoint = ecParameters.getCurve().decodePoint(publicKeyData); - return ecPoint.getEncoded(true); - } catch (IOException e) { - throw new RuntimeException("Error converting DER-encoded X.509 to SEC1 compressed format", e); - } + var biscuit = + Biscuit.builder(signer) + .addAuthorityFact("user(\"1234\")") + .addAuthorityCheck("check if operation(\"read\")") + .build(); + var serializedBiscuit = biscuit.serialize(); + var deserializedUnverifiedBiscuit = Biscuit.fromBytes(serializedBiscuit); + var verifiedBiscuit = assertDoesNotThrow(() -> deserializedUnverifiedBiscuit.verify(publicKey)); + + System.out.println(verifiedBiscuit.print()); + } + + private static byte[] convertDerEncodedX509PublicKeyToSec1CompressedEncodedPublicKey( + byte[] publicKeyBytes) { + try (ASN1InputStream asn1InputStream = + new ASN1InputStream(new ByteArrayInputStream(publicKeyBytes))) { + + // Parse the ASN.1 encoded public key bytes + var asn1Primitive = asn1InputStream.readObject(); + var subjectPublicKeyInfo = SubjectPublicKeyInfo.getInstance(asn1Primitive); + + // Extract the public key data + var publicKeyDataBitString = subjectPublicKeyInfo.getPublicKeyData(); + byte[] publicKeyData = publicKeyDataBitString.getBytes(); + + // Parse the public key data to get the elliptic curve point + var ecParameters = ECNamedCurveTable.getByName("secp256r1"); + var ecPoint = ecParameters.getCurve().decodePoint(publicKeyData); + return ecPoint.getEncoded(true); + } catch (IOException e) { + throw new RuntimeException("Error converting DER-encoded X.509 to SEC1 compressed format", e); } + } } diff --git a/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java b/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java index 776dd93d..addfcd77 100644 --- a/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java +++ b/src/test/java/org/biscuitsec/biscuit/token/SamplesTest.java @@ -1,601 +1,717 @@ package org.biscuitsec.biscuit.token; +import static org.biscuitsec.biscuit.token.Block.fromBytes; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + import biscuit.format.schema.Schema; -import com.google.gson.*; +import com.google.gson.Gson; +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; import io.vavr.Tuple2; +import io.vavr.control.Either; import io.vavr.control.Option; +import io.vavr.control.Try; +import java.io.BufferedInputStream; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.security.SecureRandom; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.biscuitsec.biscuit.crypto.KeyPair; import org.biscuitsec.biscuit.crypto.PublicKey; import org.biscuitsec.biscuit.datalog.Rule; import org.biscuitsec.biscuit.datalog.RunLimits; import org.biscuitsec.biscuit.datalog.SymbolTable; import org.biscuitsec.biscuit.error.Error; -import io.vavr.control.Either; -import io.vavr.control.Try; import org.biscuitsec.biscuit.token.builder.Check; import org.biscuitsec.biscuit.token.builder.parser.Parser; -import org.biscuitsec.biscuit.token.format.SerializedBiscuit; import org.biscuitsec.biscuit.token.format.SignedBlock; import org.junit.jupiter.api.DynamicTest; import org.junit.jupiter.api.TestFactory; -import java.io.BufferedInputStream; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.security.SecureRandom; -import java.time.Duration; -import java.util.*; -import java.util.stream.Collectors; -import java.util.stream.IntStream; -import java.util.stream.Stream; - -import static org.biscuitsec.biscuit.token.Block.from_bytes; -import static org.junit.jupiter.api.Assertions.*; - class SamplesTest { - final RunLimits runLimits = new RunLimits(500,100, Duration.ofMillis(500)); - @TestFactory - Stream jsonTest() { - InputStream inputStream = - Thread.currentThread().getContextClassLoader().getResourceAsStream("samples/samples.json"); - Gson gson = new Gson(); - Sample sample = gson.fromJson(new InputStreamReader(new BufferedInputStream(inputStream)), Sample.class); - PublicKey publicKey = new PublicKey(Schema.PublicKey.Algorithm.Ed25519, sample.root_public_key); - KeyPair keyPair = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, sample.root_private_key); - return sample.testcases.stream().map(t -> process_testcase(t, publicKey, keyPair)); - } - - void compareBlocks(KeyPair root, List sampleBlocks, Biscuit token) throws Error { - assertEquals(sampleBlocks.size(), 1+token.blocks.size()); - Option sampleToken = Option.none(); - Biscuit b = compareBlock(root, sampleToken, 0, sampleBlocks.get(0), token.authority, token.symbols); - sampleToken = Option.some(b); - - for(int i=0; i < token.blocks.size(); i++) { - b = compareBlock(root, sampleToken, i+1, sampleBlocks.get(i+1), token.blocks.get(i), token.symbols); - sampleToken = Option.some(b); - } + final RunLimits runLimits = new RunLimits(500, 100, Duration.ofMillis(500)); + + @TestFactory + Stream jsonTest() { + InputStream inputStream = + Thread.currentThread().getContextClassLoader().getResourceAsStream("samples/samples.json"); + Gson gson = new Gson(); + Sample sample = + gson.fromJson(new InputStreamReader(new BufferedInputStream(inputStream)), Sample.class); + PublicKey publicKey = new PublicKey(Schema.PublicKey.Algorithm.Ed25519, sample.root_public_key); + KeyPair keyPair = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, sample.root_private_key); + return sample.testcases.stream().map(t -> processTestcase(t, publicKey, keyPair)); + } + + void compareBlocks(KeyPair root, List sampleBlocks, Biscuit token) throws Error { + assertEquals(sampleBlocks.size(), 1 + token.blocks.size()); + Option sampleToken = Option.none(); + Biscuit b = + compareBlock(root, sampleToken, 0, sampleBlocks.get(0), token.authority, token.symbolTable); + sampleToken = Option.some(b); + + for (int i = 0; i < token.blocks.size(); i++) { + b = + compareBlock( + root, + sampleToken, + i + 1, + sampleBlocks.get(i + 1), + token.blocks.get(i), + token.symbolTable); + sampleToken = Option.some(b); + } + } + + Biscuit compareBlock( + KeyPair root, + Option sampleToken, + long sampleBlockIndex, + Block sampleBlock, + org.biscuitsec.biscuit.token.Block tokenBlock, + SymbolTable tokenSymbols) + throws Error { + Option sampleExternalKey = sampleBlock.getExternalKey(); + List samplePublicKeys = sampleBlock.getPublicKeys(); + String sampleDatalog = sampleBlock.getCode().replace("\"", "\\\""); + + Either< + Map>, + org.biscuitsec.biscuit.token.builder.Block> + outputSample = Parser.datalog(sampleBlockIndex, sampleDatalog); + + // the invalid block rule with unbound variable cannot be parsed + if (outputSample.isLeft()) { + return sampleToken.get(); } - Biscuit compareBlock(KeyPair root, Option sampleToken, long sampleBlockIndex, Block sampleBlock, org.biscuitsec.biscuit.token.Block tokenBlock, SymbolTable tokenSymbols) throws Error { - Option sampleExternalKey = sampleBlock.getExternalKey(); - List samplePublicKeys = sampleBlock.getPublicKeys(); - String sampleDatalog = sampleBlock.getCode().replace("\"","\\\""); - - Either>, org.biscuitsec.biscuit.token.builder.Block> outputSample = Parser.datalog(sampleBlockIndex, sampleDatalog); - - // the invalid block rule with unbound variable cannot be parsed - if(outputSample.isLeft()) { - return sampleToken.get(); - } - - Biscuit newSampleToken; - if(!sampleToken.isDefined()) { - org.biscuitsec.biscuit.token.builder.Biscuit builder = new org.biscuitsec.biscuit.token.builder.Biscuit(new SecureRandom(), root, Option.none(), outputSample.get()); - newSampleToken = builder.build(); - } else { - Biscuit s = sampleToken.get(); - newSampleToken = s.attenuate(outputSample.get(), Schema.PublicKey.Algorithm.Ed25519); - } - - org.biscuitsec.biscuit.token.Block generatedSampleBlock; - if(!sampleToken.isDefined()) { - generatedSampleBlock = newSampleToken.authority; - } else { - generatedSampleBlock = newSampleToken.blocks.get((int)sampleBlockIndex-1); - } - - System.out.println("generated block: "); - System.out.println(generatedSampleBlock.print(newSampleToken.symbols)); - System.out.println("deserialized block: "); - System.out.println(tokenBlock.print(newSampleToken.symbols)); - - SymbolTable tokenBlockSymbols = tokenSymbols; - SymbolTable generatedBlockSymbols = newSampleToken.symbols; - assertEquals(generatedSampleBlock.printCode(generatedBlockSymbols), tokenBlock.printCode(tokenBlockSymbols)); - - /* FIXME: to generate the same sample block, we need the samples to provide the external private key - assertEquals(generatedSampleBlock, tokenBlock); - assertArrayEquals(generatedSampleBlock.to_bytes().get(), tokenBlock.to_bytes().get()); - */ - - return newSampleToken; - } - - DynamicTest process_testcase(final TestCase testCase, final PublicKey publicKey, final KeyPair privateKey) { - return DynamicTest.dynamicTest(testCase.title + ": "+testCase.filename, () -> { - System.out.println("Testcase name: \""+testCase.title+"\""); - System.out.println("filename: \""+testCase.filename+"\""); - InputStream inputStream = Thread.currentThread().getContextClassLoader().getResourceAsStream("samples/" + testCase.filename); - byte[] data = new byte[inputStream.available()]; - - for(Map.Entry validationEntry: testCase.validations.getAsJsonObject().entrySet()) { - String validationName = validationEntry.getKey(); - JsonObject validation = validationEntry.getValue().getAsJsonObject(); - - JsonObject expected_result = validation.getAsJsonObject("result"); - String[] authorizer_facts = validation.getAsJsonPrimitive("authorizer_code").getAsString().split(";"); - Either res = Try.of(() -> { - inputStream.read(data); - Biscuit token = Biscuit.from_bytes(data, publicKey); - assertArrayEquals(token.serialize(), data); - - List allBlocks = new ArrayList<>(); - allBlocks.add(token.authority); - allBlocks.addAll(token.blocks); - - compareBlocks(privateKey, testCase.token, token); - - byte[] ser_block_authority = token.authority.to_bytes().get(); - System.out.println(Arrays.toString(ser_block_authority)); - System.out.println(Arrays.toString(token.serializedBiscuit.authority.block)); - org.biscuitsec.biscuit.token.Block deser_block_authority = from_bytes(ser_block_authority, token.authority.externalKey).get(); - assertEquals(token.authority.print(token.symbols), deser_block_authority.print(token.symbols)); - assert(Arrays.equals(ser_block_authority, token.serializedBiscuit.authority.block)); - - for(int i = 0; i < token.blocks.size() - 1; i++) { - org.biscuitsec.biscuit.token.Block block = token.blocks.get(i); - SignedBlock signed_block = token.serializedBiscuit.blocks.get(i); - byte[] ser_block = block.to_bytes().get(); - org.biscuitsec.biscuit.token.Block deser_block = from_bytes(ser_block,block.externalKey).get(); - assertEquals(block.print(token.symbols), deser_block.print(token.symbols)); - assert(Arrays.equals(ser_block, signed_block.block)); - } + Biscuit newSampleToken; + if (!sampleToken.isDefined()) { + org.biscuitsec.biscuit.token.builder.Biscuit builder = + new org.biscuitsec.biscuit.token.builder.Biscuit( + new SecureRandom(), root, Option.none(), outputSample.get()); + newSampleToken = builder.build(); + } else { + Biscuit s = sampleToken.get(); + newSampleToken = s.attenuate(outputSample.get(), Schema.PublicKey.Algorithm.Ed25519); + } - List revocationIds = token.revocation_identifiers(); - JsonArray validationRevocationIds = validation.getAsJsonArray("revocation_ids"); - assertEquals(revocationIds.size(), validationRevocationIds.size()); - for(int i = 0; i < revocationIds.size(); i++) { - assertEquals(validationRevocationIds.get(i).getAsString(), revocationIds.get(i).toHex()); - } + org.biscuitsec.biscuit.token.Block generatedSampleBlock; + if (!sampleToken.isDefined()) { + generatedSampleBlock = newSampleToken.authority; + } else { + generatedSampleBlock = newSampleToken.blocks.get((int) sampleBlockIndex - 1); + } - // TODO Add check of the token - - Authorizer authorizer = token.authorizer(); - System.out.println(token.print()); - for (String f : authorizer_facts) { - f = f.trim(); - if (f.length() > 0) { - if (f.startsWith("check if") || f.startsWith("check all")) { - authorizer.add_check(f); - } else if (f.startsWith("allow if") || f.startsWith("deny if")) { - authorizer.add_policy(f); - } else if (f.startsWith("revocation_id")) { + System.out.println("generated block: "); + System.out.println(generatedSampleBlock.print(newSampleToken.symbolTable)); + System.out.println("deserialized block: "); + System.out.println(tokenBlock.print(newSampleToken.symbolTable)); + + SymbolTable tokenBlockSymbols = tokenSymbols; + SymbolTable generatedBlockSymbols = newSampleToken.symbolTable; + assertEquals( + generatedSampleBlock.printCode(generatedBlockSymbols), + tokenBlock.printCode(tokenBlockSymbols)); + + /* FIXME: to generate the same sample block, + we need the samples to provide the external private key + assertEquals(generatedSampleBlock, tokenBlock); + assertArrayEquals(generatedSampleBlock.to_bytes().get(), tokenBlock.to_bytes().get()); + */ + + return newSampleToken; + } + + DynamicTest processTestcase( + final TestCase testCase, final PublicKey publicKey, final KeyPair privateKey) { + return DynamicTest.dynamicTest( + testCase.title + ": " + testCase.filename, + () -> { + System.out.println("Testcase name: \"" + testCase.title + "\""); + System.out.println("filename: \"" + testCase.filename + "\""); + InputStream inputStream = + Thread.currentThread() + .getContextClassLoader() + .getResourceAsStream("samples/" + testCase.filename); + byte[] data = new byte[inputStream.available()]; + + for (Map.Entry validationEntry : + testCase.validations.getAsJsonObject().entrySet()) { + String validationName = validationEntry.getKey(); + JsonObject validation = validationEntry.getValue().getAsJsonObject(); + + JsonObject expectedResult = validation.getAsJsonObject("result"); + String[] authorizerFacts = + validation.getAsJsonPrimitive("authorizer_code").getAsString().split(";"); + Either res = + Try.of( + () -> { + inputStream.read(data); + Biscuit token = Biscuit.fromBytes(data, publicKey); + assertArrayEquals(token.serialize(), data); + + List allBlocks = new ArrayList<>(); + allBlocks.add(token.authority); + allBlocks.addAll(token.blocks); + + compareBlocks(privateKey, testCase.token, token); + + byte[] serBlockAuthority = token.authority.toBytes().get(); + System.out.println(Arrays.toString(serBlockAuthority)); + System.out.println( + Arrays.toString(token.serializedBiscuit.getAuthority().getBlock())); + org.biscuitsec.biscuit.token.Block deserBlockAuthority = + fromBytes(serBlockAuthority, token.authority.getExternalKey()) + .get(); + assertEquals( + token.authority.print(token.symbolTable), + deserBlockAuthority.print(token.symbolTable)); + assert (Arrays.equals( + serBlockAuthority, + token.serializedBiscuit.getAuthority().getBlock())); + + for (int i = 0; i < token.blocks.size() - 1; i++) { + org.biscuitsec.biscuit.token.Block block = token.blocks.get(i); + SignedBlock signedBlock = token.serializedBiscuit.getBlocks().get(i); + byte[] serBlock = block.toBytes().get(); + org.biscuitsec.biscuit.token.Block deserBlock = + fromBytes(serBlock, block.getExternalKey()).get(); + assertEquals( + block.print(token.symbolTable), deserBlock.print(token.symbolTable)); + assert (Arrays.equals(serBlock, signedBlock.getBlock())); + } + + List revocationIds = token.revocationIdentifiers(); + JsonArray validationRevocationIds = + validation.getAsJsonArray("revocation_ids"); + assertEquals(revocationIds.size(), validationRevocationIds.size()); + for (int i = 0; i < revocationIds.size(); i++) { + assertEquals( + validationRevocationIds.get(i).getAsString(), + revocationIds.get(i).toHex()); + } + + // TODO Add check of the token + + Authorizer authorizer = token.authorizer(); + System.out.println(token.print()); + for (String f : authorizerFacts) { + f = f.trim(); + if (!f.isEmpty()) { + if (f.startsWith("check if") || f.startsWith("check all")) { + authorizer.addCheck(f); + } else if (f.startsWith("allow if") || f.startsWith("deny if")) { + authorizer.addPolicy(f); + } else if (f.startsWith("revocation_id")) { // do nothing - } else { - authorizer.add_fact(f); + } else { + authorizer.addFact(f); + } + } + } + System.out.println(authorizer.formatWorld()); + try { + Long authorizeResult = authorizer.authorize(runLimits); + + if (validation.has("world") && !validation.get("world").isJsonNull()) { + World world = + new Gson() + .fromJson( + validation.get("world").getAsJsonObject(), World.class); + world.fixOrigin(); + + World authorizerWorld = new World(authorizer); + assertEquals(world.factMap(), authorizerWorld.factMap()); + assertEquals(world.rules, authorizerWorld.rules); + assertEquals(world.checks, authorizerWorld.checks); + assertEquals(world.policies, authorizerWorld.policies); } - } - } - System.out.println(authorizer.print_world()); - try { - Long authorizeResult = authorizer.authorize(runLimits); - - if(validation.has("world") && !validation.get("world").isJsonNull()) { - World world = new Gson().fromJson(validation.get("world").getAsJsonObject(), World.class); - world.fixOrigin(); - - World authorizerWorld = new World(authorizer); - assertEquals(world.factMap(), authorizerWorld.factMap()); - assertEquals(world.rules, authorizerWorld.rules); - assertEquals(world.checks, authorizerWorld.checks); - assertEquals(world.policies, authorizerWorld.policies); - - - } - - return authorizeResult; - } catch (Exception e) { - - if(validation.has("world") && !validation.get("world").isJsonNull()) { - World world = new Gson().fromJson(validation.get("world").getAsJsonObject(), World.class); - world.fixOrigin(); - - World authorizerWorld = new World(authorizer); - assertEquals(world.factMap(), authorizerWorld.factMap()); - assertEquals(world.rules, authorizerWorld.rules); - assertEquals(world.checks, authorizerWorld.checks); - assertEquals(world.policies, authorizerWorld.policies); - } - throw e; - } - }).toEither(); + return authorizeResult; + } catch (Exception e) { + + if (validation.has("world") && !validation.get("world").isJsonNull()) { + World world = + new Gson() + .fromJson( + validation.get("world").getAsJsonObject(), World.class); + world.fixOrigin(); + + World authorizerWorld = new World(authorizer); + assertEquals(world.factMap(), authorizerWorld.factMap()); + assertEquals(world.rules, authorizerWorld.rules); + assertEquals(world.checks, authorizerWorld.checks); + assertEquals(world.policies, authorizerWorld.policies); + } - if(expected_result.has("Ok")) { - if (res.isLeft()) { - System.out.println("validation '"+validationName+"' expected result Ok("+expected_result.getAsJsonPrimitive("Ok").getAsLong()+"), got error"); - throw res.getLeft(); - } else { - assertEquals(expected_result.getAsJsonPrimitive("Ok").getAsLong(), res.get()); - } + throw e; + } + }) + .toEither(); + + if (expectedResult.has("Ok")) { + if (res.isLeft()) { + System.out.println( + "validation '" + + validationName + + "' expected result Ok(" + + expectedResult.getAsJsonPrimitive("Ok").getAsLong() + + "), got error"); + throw res.getLeft(); + } else { + assertEquals(expectedResult.getAsJsonPrimitive("Ok").getAsLong(), res.get()); + } + } else { + if (res.isLeft()) { + if (res.getLeft() instanceof Error) { + Error e = (Error) res.getLeft(); + System.out.println("validation '" + validationName + "' got error: " + e); + JsonElement errJson = e.toJson(); + assertEquals(expectedResult.get("Err"), errJson); } else { - if (res.isLeft()) { - if(res.getLeft() instanceof Error) { - Error e = (Error) res.getLeft(); - System.out.println("validation '"+validationName+"' got error: " + e); - JsonElement err_json = e.toJson(); - assertEquals(expected_result.get("Err"), err_json); - } else { - throw res.getLeft(); - } - } else { - throw new Exception("validation '"+validationName+"' expected result error("+expected_result.get("Err")+"), got success: "+res.get()); - } + throw res.getLeft(); } + } else { + throw new Exception( + "validation '" + + validationName + + "' expected result error(" + + expectedResult.get("Err") + + "), got success: " + + res.get()); + } } + } }); + } + + class Block { + List symbols; + String code; + @SuppressWarnings("checkstyle:MemberName") + List public_keys; + @SuppressWarnings("checkstyle:MemberName") + String external_key; + + public List getSymbols() { + return symbols; } - class Block { - List symbols; - String code; - List public_keys; - String external_key; - - public List getSymbols() { - return symbols; - } - - public void setSymbols(List symbols) { - this.symbols = symbols; - } - - public String getCode() { return code; } - - public void setCode(String code) { this.code = code; } + public void setSymbols(List symbols) { + this.symbols = symbols; + } - public List getPublicKeys() { - return this.public_keys.stream() - .map(pk -> - Parser.publicKey(pk).fold(e -> { throw new IllegalArgumentException(e.toString());}, r -> r._2) - ) - .collect(Collectors.toList()); - } + public String getCode() { + return code; + } - public void setPublicKeys(List publicKeys) { - this.public_keys = publicKeys.stream() - .map(PublicKey::toString) - .collect(Collectors.toList()); - } + public void setCode(String code) { + this.code = code; + } - public Option getExternalKey() { - if (this.external_key != null) { - PublicKey externalKey = Parser.publicKey(this.external_key) - .fold(e -> { throw new IllegalArgumentException(e.toString());}, r -> r._2); - return Option.of(externalKey); - } else { - return Option.none(); - } - } + public List getPublicKeys() { + return this.public_keys.stream() + .map( + pk -> + Parser.publicKey(pk) + .fold( + e -> { + throw new IllegalArgumentException(e.toString()); + }, + r -> r._2)) + .collect(Collectors.toList()); + } - public void setExternalKey(Option externalKey) { - this.external_key = externalKey.map(PublicKey::toString).getOrElse((String) null); - } + public void setPublicKeys(List publicKeys) { + this.public_keys = publicKeys.stream().map(PublicKey::toString).collect(Collectors.toList()); } - class TestCase { - String title; + public Option getExternalKey() { + if (this.external_key != null) { + PublicKey externalKey = + Parser.publicKey(this.external_key) + .fold( + e -> { + throw new IllegalArgumentException(e.toString()); + }, + r -> r._2); + return Option.of(externalKey); + } else { + return Option.none(); + } + } - public String getTitle() { - return title; - } + public void setExternalKey(Option externalKey) { + this.external_key = externalKey.map(PublicKey::toString).getOrElse((String) null); + } + } - public void setTitle(String title) { - this.title = title; - } + class TestCase { + String title; - String filename; - List token; - JsonElement validations; + public String getTitle() { + return title; + } - public String getFilename() { - return filename; - } + public void setTitle(String title) { + this.title = title; + } - public void setFilename(String filename) { - this.filename = filename; - } + String filename; + List token; + JsonElement validations; - public List getToken() { - return token; - } + public String getFilename() { + return filename; + } - public void setTokens(List token) { - this.token = token; - } + public void setFilename(String filename) { + this.filename = filename; + } - public JsonElement getValidations() { - return validations; - } + public List getToken() { + return token; + } - public void setValidations(JsonElement validations) { - this.validations = validations; - } + public void setTokens(List token) { + this.token = token; } - class Sample { - String root_private_key; + public JsonElement getValidations() { + return validations; + } - public String getRoot_public_key() { - return root_public_key; - } + public void setValidations(JsonElement validations) { + this.validations = validations; + } + } - public void setRoot_public_key(String root_public_key) { - this.root_public_key = root_public_key; - } + class Sample { + @SuppressWarnings("checkstyle:MemberName") + String root_private_key; - String root_public_key; - List testcases; + @SuppressWarnings("checkstyle:MethodName") + public String getRoot_public_key() { + return root_public_key; + } - public String getRoot_private_key() { - return root_private_key; - } + @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) + public void setRoot_public_key(String root_public_key) { + this.root_public_key = root_public_key; + } - public void setRoot_private_key(String root_private_key) { - this.root_private_key = root_private_key; - } + @SuppressWarnings("checkstyle:MemberName") + String root_public_key; + List testcases; - public List getTestcases() { - return testcases; - } + @SuppressWarnings("checkstyle:MethodName") + public String getRoot_private_key() { + return root_private_key; + } - public void setTestcases(List testcases) { - this.testcases = testcases; - } + @SuppressWarnings({"checkstyle:MethodName", "checkstyle:ParameterName"}) + public void setRoot_private_key(String root_private_key) { + this.root_private_key = root_private_key; } - class World { - List facts; - List rules; - List checks; - List policies; + public List getTestcases() { + return testcases; + } - public World(List facts, List rules, List checks, List policies) { - this.facts = facts; - this.rules = rules; - this.checks = checks; - this.policies = policies; - } + public void setTestcases(List testcases) { + this.testcases = testcases; + } + } + + class World { + List facts; + List rules; + List checks; + List policies; + + public World( + List facts, List rules, List checks, List policies) { + this.facts = facts; + this.rules = rules; + this.checks = checks; + this.policies = policies; + } - public World(Authorizer authorizer) { - this.facts = authorizer.facts().facts().entrySet().stream().map(entry -> { - ArrayList origin = new ArrayList<>(entry.getKey().inner); - Collections.sort(origin); - ArrayList facts = new ArrayList<>(entry.getValue().stream() - .map(f -> authorizer.symbols.print_fact(f)).collect(Collectors.toList())); - Collections.sort(facts); - - return new FactSet(origin, facts); - }).collect(Collectors.toList()); - - HashMap> rules = new HashMap<>(); - for(List> l: authorizer.rules().rules.values()) { - for(Tuple2 t: l) { - if (!rules.containsKey(t._1)) { - rules.put(t._1, new ArrayList<>()); + public World(Authorizer authorizer) { + this.facts = + authorizer.getFacts().facts().entrySet().stream() + .map( + entry -> { + ArrayList origin = new ArrayList<>(entry.getKey().blockIds()); + Collections.sort(origin); + ArrayList facts = + new ArrayList<>( + entry.getValue().stream() + .map(f -> authorizer.getSymbolTable().formatFact(f)) + .collect(Collectors.toList())); + Collections.sort(facts); + + return new FactSet(origin, facts); + }) + .collect(Collectors.toList()); + + HashMap> rules = new HashMap<>(); + for (List> l : authorizer.getRules().getRules().values()) { + for (Tuple2 t : l) { + if (!rules.containsKey(t._1)) { + rules.put(t._1, new ArrayList<>()); + } + rules.get(t._1).add(authorizer.getSymbolTable().formatRule(t._2)); + } + } + for (Map.Entry> entry : rules.entrySet()) { + Collections.sort(entry.getValue()); + } + List rulesets = + rules.entrySet().stream() + .map(entry -> new RuleSet(entry.getKey(), entry.getValue())) + .collect(Collectors.toList()); + Collections.sort(rulesets); + + this.rules = rulesets; + + this.checks = + authorizer.getChecks().stream() + .map( + (Tuple2> t) -> { + List checks1 = + t._2.stream().map(c -> c.toString()).collect(Collectors.toList()); + Collections.sort(checks1); + if (t._1 == null) { + return new CheckSet(checks1); + } else { + return new CheckSet(t._1, checks1); } - rules.get(t._1).add(authorizer.symbols.print_rule(t._2)); - } - } - for(Map.Entry> entry: rules.entrySet()) { - Collections.sort(entry.getValue()); - } - List rulesets = rules.entrySet().stream() - .map(entry -> new RuleSet(entry.getKey(), entry.getValue())) - .collect(Collectors.toList()); - Collections.sort(rulesets); - - this.rules = rulesets; - - this.checks = authorizer.checks().stream() - .map((Tuple2> t) -> { - List checks1 = t._2.stream().map(c -> c.toString()).collect(Collectors.toList()); - Collections.sort(checks1); - if(t._1 == null) { - return new CheckSet(checks1); - } else { - return new CheckSet(t._1, checks1); - } - }).collect(Collectors.toList()); - this.policies = authorizer.policies().stream().map(p -> p.toString()).collect(Collectors.toList()); - Collections.sort(this.rules); - Collections.sort(this.checks); - } - - public void fixOrigin() { - for(FactSet f: this.facts) { - f.fixOrigin(); - } - for(RuleSet r: this.rules) { - r.fixOrigin(); - } - Collections.sort(this.rules); - for(CheckSet c: this.checks) { - c.fixOrigin(); - } - Collections.sort(this.checks); - } - - public HashMap, List> factMap() { - HashMap, List> worldFacts = new HashMap<>(); - for(FactSet f: this.facts) { - worldFacts.put(f.origin, f.facts); - } - - return worldFacts; - } - - @Override - public String toString() { - return "World{\n" + - "facts=" + facts + - ",\nrules=" + rules + - ",\nchecks=" + checks + - ",\npolicies=" + policies + - '}'; - } + }) + .collect(Collectors.toList()); + this.policies = + authorizer.getPolicies().stream().map(p -> p.toString()).collect(Collectors.toList()); + Collections.sort(this.rules); + Collections.sort(this.checks); } - class FactSet { - List origin; - List facts; - - public FactSet(List origin, List facts) { - this.origin = origin; - this.facts = facts; - } - - // JSON cannot represent Long.MAX_VALUE so it is stored as null, fix the origin list - public void fixOrigin() { - for(int i = 0; i < this.origin.size(); i++) { - if (this.origin.get(i) == null) { - this.origin.set(i, Long.MAX_VALUE); - } - } - Collections.sort(this.origin); - } + public void fixOrigin() { + for (FactSet f : this.facts) { + f.fixOrigin(); + } + for (RuleSet r : this.rules) { + r.fixOrigin(); + } + Collections.sort(this.rules); + for (CheckSet c : this.checks) { + c.fixOrigin(); + } + Collections.sort(this.checks); + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + public HashMap, List> factMap() { + HashMap, List> worldFacts = new HashMap<>(); + for (FactSet f : this.facts) { + worldFacts.put(f.origin, f.facts); + } - FactSet factSet = (FactSet) o; + return worldFacts; + } - if (!Objects.equals(origin, factSet.origin)) return false; - return Objects.equals(facts, factSet.facts); - } + @Override + public String toString() { + return "World{\n" + + "facts=" + + facts + + ",\nrules=" + + rules + + ",\nchecks=" + + checks + + ",\npolicies=" + + policies + + '}'; + } + } - @Override - public int hashCode() { - int result = origin != null ? origin.hashCode() : 0; - result = 31 * result + (facts != null ? facts.hashCode() : 0); - return result; - } + class FactSet { + List origin; + List facts; - @Override - public String toString() { - return "FactSet{" + - "origin=" + origin + - ", facts=" + facts + - '}'; - } + public FactSet(List origin, List facts) { + this.origin = origin; + this.facts = facts; } - class RuleSet implements Comparable { - Long origin; - List rules; - - public RuleSet(Long origin, List rules) { - this.origin = origin; - this.rules = rules; + // JSON cannot represent Long.MAX_VALUE so it is stored as null, fix the origin list + public void fixOrigin() { + for (int i = 0; i < this.origin.size(); i++) { + if (this.origin.get(i) == null) { + this.origin.set(i, Long.MAX_VALUE); } + } + Collections.sort(this.origin); + } - public void fixOrigin() { - if (this.origin == null || this.origin == -1) { - this.origin = Long.MAX_VALUE; - } - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FactSet factSet = (FactSet) o; + + if (!Objects.equals(origin, factSet.origin)) { + return false; + } + return Objects.equals(facts, factSet.facts); + } - @Override - public int compareTo(RuleSet ruleSet) { - // we only compare origin to sort the list of rulesets - // there's only one of each origin so we don't need to compare the list of rules - if(this.origin == null) { - return -1; - } else if (ruleSet.origin == null) { - return 1; - } else { - return this.origin.compareTo(ruleSet.origin); - } - } + @Override + public int hashCode() { + int result = origin != null ? origin.hashCode() : 0; + result = 31 * result + (facts != null ? facts.hashCode() : 0); + return result; + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + @Override + public String toString() { + return "FactSet{" + "origin=" + origin + ", facts=" + facts + '}'; + } + } - RuleSet ruleSet = (RuleSet) o; + class RuleSet implements Comparable { + Long origin; + List rules; - if (!Objects.equals(origin, ruleSet.origin)) return false; - return Objects.equals(rules, ruleSet.rules); - } + public RuleSet(Long origin, List rules) { + this.origin = origin; + this.rules = rules; + } - @Override - public int hashCode() { - int result = origin != null ? origin.hashCode() : 0; - result = 31 * result + (rules != null ? rules.hashCode() : 0); - return result; - } + public void fixOrigin() { + if (this.origin == null || this.origin == -1) { + this.origin = Long.MAX_VALUE; + } + } - @Override - public String toString() { - return "RuleSet{" + - "origin=" + origin + - ", rules=" + rules + - '}'; - } + @Override + public int compareTo(RuleSet ruleSet) { + // we only compare origin to sort the list of rulesets + // there's only one of each origin so we don't need to compare the list of rules + if (this.origin == null) { + return -1; + } else if (ruleSet.origin == null) { + return 1; + } else { + return this.origin.compareTo(ruleSet.origin); + } } - class CheckSet implements Comparable { - Long origin; - List checks; + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + RuleSet ruleSet = (RuleSet) o; + + if (!Objects.equals(origin, ruleSet.origin)) { + return false; + } + return Objects.equals(rules, ruleSet.rules); + } - public CheckSet(Long origin, List checks) { - this.origin = origin; - this.checks = checks; - } + @Override + public int hashCode() { + int result = origin != null ? origin.hashCode() : 0; + result = 31 * result + (rules != null ? rules.hashCode() : 0); + return result; + } - public CheckSet(List checks) { - this.origin = null; - this.checks = checks; - } + @Override + public String toString() { + return "RuleSet{" + "origin=" + origin + ", rules=" + rules + '}'; + } + } - public void fixOrigin() { - if (this.origin == null || this.origin == -1) { - this.origin = Long.MAX_VALUE; - } - } + class CheckSet implements Comparable { + Long origin; + List checks; - @Override - public int compareTo(CheckSet checkSet) { - // we only compare origin to sort the list of checksets - // there's only one of each origin so we don't need to compare the list of rules - if(this.origin == null) { - return -1; - } else if (checkSet.origin == null) { - return 1; - } else { - return this.origin.compareTo(checkSet.origin); - } - } + public CheckSet(Long origin, List checks) { + this.origin = origin; + this.checks = checks; + } - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + public CheckSet(List checks) { + this.origin = null; + this.checks = checks; + } - CheckSet checkSet = (CheckSet) o; + public void fixOrigin() { + if (this.origin == null || this.origin == -1) { + this.origin = Long.MAX_VALUE; + } + } - if (!Objects.equals(origin, checkSet.origin)) return false; - return Objects.equals(checks, checkSet.checks); - } + @Override + public int compareTo(CheckSet checkSet) { + // we only compare origin to sort the list of checksets + // there's only one of each origin so we don't need to compare the list of rules + if (this.origin == null) { + return -1; + } else if (checkSet.origin == null) { + return 1; + } else { + return this.origin.compareTo(checkSet.origin); + } + } - @Override - public int hashCode() { - int result = origin != null ? origin.hashCode() : 0; - result = 31 * result + (checks != null ? checks.hashCode() : 0); - return result; - } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + CheckSet checkSet = (CheckSet) o; + + if (!Objects.equals(origin, checkSet.origin)) { + return false; + } + return Objects.equals(checks, checkSet.checks); + } - @Override - public String toString() { - return "CheckSet{" + - "origin=" + origin + - ", checks=" + checks + - '}'; - } + @Override + public int hashCode() { + int result = origin != null ? origin.hashCode() : 0; + result = 31 * result + (checks != null ? checks.hashCode() : 0); + return result; } + @Override + public String toString() { + return "CheckSet{" + "origin=" + origin + ", checks=" + checks + '}'; + } + } } diff --git a/src/test/java/org/biscuitsec/biscuit/token/ThirdPartyTest.java b/src/test/java/org/biscuitsec/biscuit/token/ThirdPartyTest.java index 7f520167..92621510 100644 --- a/src/test/java/org/biscuitsec/biscuit/token/ThirdPartyTest.java +++ b/src/test/java/org/biscuitsec/biscuit/token/ThirdPartyTest.java @@ -1,14 +1,8 @@ package org.biscuitsec.biscuit.token; -import biscuit.format.schema.Schema; -import org.biscuitsec.biscuit.crypto.KeyPair; -import org.biscuitsec.biscuit.datalog.RunLimits; -import org.biscuitsec.biscuit.error.Error; -import org.biscuitsec.biscuit.error.FailedCheck; -import org.biscuitsec.biscuit.error.LogicError; -import org.biscuitsec.biscuit.token.builder.Block; -import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertEquals; +import biscuit.format.schema.Schema; import java.io.IOException; import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; @@ -16,208 +10,237 @@ import java.security.SignatureException; import java.time.Duration; import java.util.Arrays; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; +import org.biscuitsec.biscuit.crypto.KeyPair; +import org.biscuitsec.biscuit.datalog.RunLimits; +import org.biscuitsec.biscuit.error.Error; +import org.biscuitsec.biscuit.error.FailedCheck; +import org.biscuitsec.biscuit.error.LogicError; +import org.biscuitsec.biscuit.token.builder.Block; +import org.junit.jupiter.api.Test; public class ThirdPartyTest { - @Test - public void testRoundTrip() throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, CloneNotSupportedException, Error, IOException { - byte[] seed = {0, 0, 0, 0}; - SecureRandom rng = new SecureRandom(seed); - - System.out.println("preparing the authority block"); - - KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - KeyPair external = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - System.out.println("external: ed25519/"+external.public_key().toHex()); - - Block authority_builder = new Block(); - authority_builder.add_fact("right(\"read\")"); - authority_builder.add_check("check if group(\"admin\") trusting ed25519/"+external.public_key().toHex()); - - Biscuit b1 = Biscuit.make(rng, root, authority_builder.build()); - ThirdPartyBlockRequest request = b1.thirdPartyRequest(); - byte[] reqb = request.toBytes(); - ThirdPartyBlockRequest reqdeser = ThirdPartyBlockRequest.fromBytes(reqb); - assertEquals(request, reqdeser); - - Block builder = new Block(); - builder.add_fact("group(\"admin\")"); - builder.add_check("check if resource(\"file1\")"); - - ThirdPartyBlockContents blockResponse = request.createBlock(external, builder).get(); - byte[] resb = blockResponse.toBytes(); - ThirdPartyBlockContents resdeser = ThirdPartyBlockContents.fromBytes(resb); - assertEquals(blockResponse, resdeser); - - Biscuit b2 = b1.appendThirdPartyBlock(external.public_key(), blockResponse); - - byte[] data = b2.serialize(); - Biscuit deser = Biscuit.from_bytes(data, root.public_key()); - assertEquals(b2.print(), deser.print()); - - System.out.println("will check the token for resource=file1"); - Authorizer authorizer = deser.authorizer(); - authorizer.add_fact("resource(\"file1\")"); - authorizer.add_policy("allow if true"); - authorizer.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); - - System.out.println("will check the token for resource=file2"); - Authorizer authorizer2 = deser.authorizer(); - authorizer2.add_fact("resource(\"file2\")"); - authorizer2.add_policy("allow if true"); - - try { - authorizer2.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); - } catch (Error e) { - System.out.println(e); - assertEquals( - new Error.FailedLogic(new LogicError.Unauthorized(new LogicError.MatchedPolicy.Allow(0), Arrays.asList( - new FailedCheck.FailedBlock(1, 0, "check if resource(\"file1\")") - ))), - e); - } + @Test + public void testRoundTrip() + throws NoSuchAlgorithmException, + SignatureException, + InvalidKeyException, + CloneNotSupportedException, + Error, + IOException { + byte[] seed = {0, 0, 0, 0}; + SecureRandom rng = new SecureRandom(seed); + + System.out.println("preparing the authority block"); + + final KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + final KeyPair external = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + System.out.println("external: ed25519/" + external.getPublicKey().toHex()); + + Block authorityBuilder = new Block(); + authorityBuilder.addFact("right(\"read\")"); + authorityBuilder.addCheck( + "check if group(\"admin\") trusting ed25519/" + external.getPublicKey().toHex()); + + Biscuit b1 = Biscuit.make(rng, root, authorityBuilder.build()); + ThirdPartyBlockRequest request = b1.thirdPartyRequest(); + byte[] reqb = request.toBytes(); + ThirdPartyBlockRequest reqdeser = ThirdPartyBlockRequest.fromBytes(reqb); + assertEquals(request, reqdeser); + + Block builder = new Block(); + builder.addFact("group(\"admin\")"); + builder.addCheck("check if resource(\"file1\")"); + + ThirdPartyBlockContents blockResponse = request.createBlock(external, builder).get(); + byte[] resb = blockResponse.toBytes(); + ThirdPartyBlockContents resdeser = ThirdPartyBlockContents.fromBytes(resb); + assertEquals(blockResponse, resdeser); + + Biscuit b2 = b1.appendThirdPartyBlock(external.getPublicKey(), blockResponse); + + byte[] data = b2.serialize(); + Biscuit deser = Biscuit.fromBytes(data, root.getPublicKey()); + assertEquals(b2.print(), deser.print()); + + System.out.println("will check the token for resource=file1"); + Authorizer authorizer = deser.authorizer(); + authorizer.addFact("resource(\"file1\")"); + authorizer.addPolicy("allow if true"); + authorizer.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); + + System.out.println("will check the token for resource=file2"); + Authorizer authorizer2 = deser.authorizer(); + authorizer2.addFact("resource(\"file2\")"); + authorizer2.addPolicy("allow if true"); + + try { + authorizer2.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); + } catch (Error e) { + System.out.println(e); + assertEquals( + new Error.FailedLogic( + new LogicError.Unauthorized( + new LogicError.MatchedPolicy.Allow(0), + Arrays.asList( + new FailedCheck.FailedBlock(1, 0, "check if resource(\"file1\")")))), + e); } - - @Test - public void testPublicKeyInterning() throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, CloneNotSupportedException, Error { - // this makes a deterministic RNG - SecureRandom rng = SecureRandom.getInstance("SHA1PRNG"); - byte[] seed = {0, 0, 0, 0}; - rng.setSeed(seed); - - System.out.println("preparing the authority block"); - - KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - KeyPair external1 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - KeyPair external2 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - KeyPair external3 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - //System.out.println("external: ed25519/"+external.public_key().toHex()); - - Block authority_builder = new Block(); - authority_builder.add_fact("right(\"read\")"); - authority_builder.add_check("check if first(\"admin\") trusting ed25519/"+external1.public_key().toHex()); - - org.biscuitsec.biscuit.token.Block authority_block = authority_builder.build(); - System.out.println(authority_block); - Biscuit b1 = Biscuit.make(rng, root, authority_block); - System.out.println("TOKEN: "+b1.print()); - - ThirdPartyBlockRequest request1 = b1.thirdPartyRequest(); - Block builder = new Block(); - builder.add_fact("first(\"admin\")"); - builder.add_fact("second(\"A\")"); - builder.add_check("check if third(3) trusting ed25519/"+external2.public_key().toHex()); - ThirdPartyBlockContents blockResponse = request1.createBlock(external1, builder).get(); - Biscuit b2 = b1.appendThirdPartyBlock(external1.public_key(), blockResponse); - byte[] data = b2.serialize(); - Biscuit deser2 = Biscuit.from_bytes(data, root.public_key()); - assertEquals(b2.print(), deser2.print()); - System.out.println("TOKEN: "+deser2.print()); - - ThirdPartyBlockRequest request2 = deser2.thirdPartyRequest(); - Block builder2 = new Block(); - builder2.add_fact("third(3)"); - builder2.add_check("check if fourth(1) trusting ed25519/"+external3.public_key().toHex()+", ed25519/"+external1.public_key().toHex()); - ThirdPartyBlockContents blockResponse2 = request2.createBlock(external2, builder2).get(); - Biscuit b3 = deser2.appendThirdPartyBlock(external2.public_key(), blockResponse2); - byte[] data2 = b3.serialize(); - Biscuit deser3 = Biscuit.from_bytes(data2, root.public_key()); - assertEquals(b3.print(), deser3.print()); - System.out.println("TOKEN: "+deser3.print()); - - - ThirdPartyBlockRequest request3 = deser3.thirdPartyRequest(); - Block builder3 = new Block(); - builder3.add_fact("fourth(1)"); - builder3.add_check("check if resource(\"file1\")"); - ThirdPartyBlockContents blockResponse3 = request3.createBlock(external1, builder3).get(); - Biscuit b4 = deser3.appendThirdPartyBlock(external1.public_key(), blockResponse3); - byte[] data3 = b4.serialize(); - Biscuit deser4 = Biscuit.from_bytes(data3, root.public_key()); - assertEquals(b4.print(), deser4.print()); - System.out.println("TOKEN: "+deser4.print()); - - - System.out.println("will check the token for resource=file1"); - Authorizer authorizer = deser4.authorizer(); - authorizer.add_fact("resource(\"file1\")"); - authorizer.add_policy("allow if true"); - System.out.println("Authorizer world:\n"+authorizer.print_world()); - authorizer.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); - - System.out.println("will check the token for resource=file2"); - Authorizer authorizer2 = deser4.authorizer(); - authorizer2.add_fact("resource(\"file2\")"); - authorizer2.add_policy("allow if true"); - System.out.println("Authorizer world 2:\n"+authorizer2.print_world()); - - try { - authorizer2.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); - } catch (Error e) { - System.out.println(e); - assertEquals( - new Error.FailedLogic(new LogicError.Unauthorized(new LogicError.MatchedPolicy.Allow(0), Arrays.asList( - new FailedCheck.FailedBlock(3, 0, "check if resource(\"file1\")") - ))), - e); - } + } + + @Test + public void testPublicKeyInterning() + throws NoSuchAlgorithmException, + SignatureException, + InvalidKeyException, + CloneNotSupportedException, + Error { + // this makes a deterministic RNG + SecureRandom rng = SecureRandom.getInstance("SHA1PRNG"); + byte[] seed = {0, 0, 0, 0}; + rng.setSeed(seed); + + System.out.println("preparing the authority block"); + + final KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + final KeyPair external1 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + final KeyPair external2 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + final KeyPair external3 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + // System.out.println("external: ed25519/" + external.public_key().toHex()); + + Block authorityBuilder = new Block(); + authorityBuilder.addFact("right(\"read\")"); + authorityBuilder.addCheck( + "check if first(\"admin\") trusting ed25519/" + external1.getPublicKey().toHex()); + + org.biscuitsec.biscuit.token.Block authorityBlock = authorityBuilder.build(); + System.out.println(authorityBlock); + Biscuit b1 = Biscuit.make(rng, root, authorityBlock); + System.out.println("TOKEN: " + b1.print()); + + ThirdPartyBlockRequest request1 = b1.thirdPartyRequest(); + Block builder = new Block(); + builder.addFact("first(\"admin\")"); + builder.addFact("second(\"A\")"); + builder.addCheck("check if third(3) trusting ed25519/" + external2.getPublicKey().toHex()); + ThirdPartyBlockContents blockResponse = request1.createBlock(external1, builder).get(); + Biscuit b2 = b1.appendThirdPartyBlock(external1.getPublicKey(), blockResponse); + byte[] data = b2.serialize(); + Biscuit deser2 = Biscuit.fromBytes(data, root.getPublicKey()); + assertEquals(b2.print(), deser2.print()); + System.out.println("TOKEN: " + deser2.print()); + + ThirdPartyBlockRequest request2 = deser2.thirdPartyRequest(); + Block builder2 = new Block(); + builder2.addFact("third(3)"); + builder2.addCheck( + "check if fourth(1) trusting ed25519/" + + external3.getPublicKey().toHex() + + ", ed25519/" + + external1.getPublicKey().toHex()); + ThirdPartyBlockContents blockResponse2 = request2.createBlock(external2, builder2).get(); + Biscuit b3 = deser2.appendThirdPartyBlock(external2.getPublicKey(), blockResponse2); + byte[] data2 = b3.serialize(); + Biscuit deser3 = Biscuit.fromBytes(data2, root.getPublicKey()); + assertEquals(b3.print(), deser3.print()); + System.out.println("TOKEN: " + deser3.print()); + + ThirdPartyBlockRequest request3 = deser3.thirdPartyRequest(); + Block builder3 = new Block(); + builder3.addFact("fourth(1)"); + builder3.addCheck("check if resource(\"file1\")"); + ThirdPartyBlockContents blockResponse3 = request3.createBlock(external1, builder3).get(); + Biscuit b4 = deser3.appendThirdPartyBlock(external1.getPublicKey(), blockResponse3); + byte[] data3 = b4.serialize(); + Biscuit deser4 = Biscuit.fromBytes(data3, root.getPublicKey()); + assertEquals(b4.print(), deser4.print()); + System.out.println("TOKEN: " + deser4.print()); + + System.out.println("will check the token for resource=file1"); + Authorizer authorizer = deser4.authorizer(); + authorizer.addFact("resource(\"file1\")"); + authorizer.addPolicy("allow if true"); + System.out.println("Authorizer world:\n" + authorizer.formatWorld()); + authorizer.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); + + System.out.println("will check the token for resource=file2"); + Authorizer authorizer2 = deser4.authorizer(); + authorizer2.addFact("resource(\"file2\")"); + authorizer2.addPolicy("allow if true"); + System.out.println("Authorizer world 2:\n" + authorizer2.formatWorld()); + + try { + authorizer2.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); + } catch (Error e) { + System.out.println(e); + assertEquals( + new Error.FailedLogic( + new LogicError.Unauthorized( + new LogicError.MatchedPolicy.Allow(0), + Arrays.asList( + new FailedCheck.FailedBlock(3, 0, "check if resource(\"file1\")")))), + e); } - - @Test - public void testReusedSymbols() throws NoSuchAlgorithmException, SignatureException, InvalidKeyException, CloneNotSupportedException, Error { - byte[] seed = {0, 0, 0, 0}; - SecureRandom rng = new SecureRandom(seed); - - System.out.println("preparing the authority block"); - - KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - KeyPair external = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - System.out.println("external: ed25519/"+external.public_key().toHex()); - - Block authority_builder = new Block(); - authority_builder.add_fact("right(\"read\")"); - authority_builder.add_check("check if group(\"admin\") trusting ed25519/"+external.public_key().toHex()); - - Biscuit b1 = Biscuit.make(rng, root, authority_builder.build()); - ThirdPartyBlockRequest request = b1.thirdPartyRequest(); - Block builder = new Block(); - builder.add_fact("group(\"admin\")"); - builder.add_fact("resource(\"file2\")"); - builder.add_check("check if resource(\"file1\")"); - builder.add_check("check if right(\"read\")"); - - ThirdPartyBlockContents blockResponse = request.createBlock(external, builder).get(); - Biscuit b2 = b1.appendThirdPartyBlock(external.public_key(), blockResponse); - - byte[] data = b2.serialize(); - Biscuit deser = Biscuit.from_bytes(data, root.public_key()); - assertEquals(b2.print(), deser.print()); - - System.out.println("will check the token for resource=file1"); - Authorizer authorizer = deser.authorizer(); - authorizer.add_fact("resource(\"file1\")"); - authorizer.add_policy("allow if true"); - authorizer.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); - System.out.println("Authorizer world:\n"+authorizer.print_world()); - - - System.out.println("will check the token for resource=file2"); - Authorizer authorizer2 = deser.authorizer(); - authorizer2.add_fact("resource(\"file2\")"); - authorizer2.add_policy("allow if true"); - - try { - authorizer2.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); - } catch (Error e) { - System.out.println(e); - assertEquals( - new Error.FailedLogic(new LogicError.Unauthorized(new LogicError.MatchedPolicy.Allow(0), Arrays.asList( - new FailedCheck.FailedBlock(1, 0, "check if resource(\"file1\")") - ))), - e); - } + } + + @Test + public void testReusedSymbols() + throws NoSuchAlgorithmException, + SignatureException, + InvalidKeyException, + CloneNotSupportedException, + Error { + byte[] seed = {0, 0, 0, 0}; + SecureRandom rng = new SecureRandom(seed); + + System.out.println("preparing the authority block"); + + final KeyPair root = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + final KeyPair external = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + System.out.println("external: ed25519/" + external.getPublicKey().toHex()); + + Block authorityBuilder = new Block(); + authorityBuilder.addFact("right(\"read\")"); + authorityBuilder.addCheck( + "check if group(\"admin\") trusting ed25519/" + external.getPublicKey().toHex()); + + Biscuit b1 = Biscuit.make(rng, root, authorityBuilder.build()); + ThirdPartyBlockRequest request = b1.thirdPartyRequest(); + Block builder = new Block(); + builder.addFact("group(\"admin\")"); + builder.addFact("resource(\"file2\")"); + builder.addCheck("check if resource(\"file1\")"); + builder.addCheck("check if right(\"read\")"); + + ThirdPartyBlockContents blockResponse = request.createBlock(external, builder).get(); + Biscuit b2 = b1.appendThirdPartyBlock(external.getPublicKey(), blockResponse); + + byte[] data = b2.serialize(); + Biscuit deser = Biscuit.fromBytes(data, root.getPublicKey()); + assertEquals(b2.print(), deser.print()); + + System.out.println("will check the token for resource=file1"); + Authorizer authorizer = deser.authorizer(); + authorizer.addFact("resource(\"file1\")"); + authorizer.addPolicy("allow if true"); + authorizer.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); + System.out.println("Authorizer world:\n" + authorizer.formatWorld()); + + System.out.println("will check the token for resource=file2"); + Authorizer authorizer2 = deser.authorizer(); + authorizer2.addFact("resource(\"file2\")"); + authorizer2.addPolicy("allow if true"); + + try { + authorizer2.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); + } catch (Error e) { + System.out.println(e); + assertEquals( + new Error.FailedLogic( + new LogicError.Unauthorized( + new LogicError.MatchedPolicy.Allow(0), + Arrays.asList( + new FailedCheck.FailedBlock(1, 0, "check if resource(\"file1\")")))), + e); } + } } - diff --git a/src/test/java/org/biscuitsec/biscuit/token/UnverifiedBiscuitTest.java b/src/test/java/org/biscuitsec/biscuit/token/UnverifiedBiscuitTest.java index 323365bc..2c24e368 100644 --- a/src/test/java/org/biscuitsec/biscuit/token/UnverifiedBiscuitTest.java +++ b/src/test/java/org/biscuitsec/biscuit/token/UnverifiedBiscuitTest.java @@ -1,9 +1,17 @@ package org.biscuitsec.biscuit.token; +import static org.junit.jupiter.api.Assertions.assertEquals; + import biscuit.format.schema.Schema; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.security.SignatureException; +import java.time.Duration; +import java.util.Arrays; +import java.util.List; import org.biscuitsec.biscuit.crypto.KeyPair; import org.biscuitsec.biscuit.datalog.RunLimits; -import org.biscuitsec.biscuit.datalog.SymbolTable; import org.biscuitsec.biscuit.error.Error; import org.biscuitsec.biscuit.error.FailedCheck; import org.biscuitsec.biscuit.error.LogicError; @@ -11,140 +19,136 @@ import org.biscuitsec.biscuit.token.builder.Utils; import org.junit.jupiter.api.Test; -import java.security.InvalidKeyException; -import java.security.NoSuchAlgorithmException; -import java.security.SecureRandom; -import java.security.SignatureException; -import java.time.Duration; -import java.util.Arrays; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; - public class UnverifiedBiscuitTest { - @Test - public void testBasic() throws Error, NoSuchAlgorithmException, SignatureException, InvalidKeyException { - byte[] seed = {0, 0, 0, 0}; - SecureRandom rng = new SecureRandom(seed); + @Test + public void testBasic() + throws Error, NoSuchAlgorithmException, SignatureException, InvalidKeyException { + byte[] seed = {0, 0, 0, 0}; + SecureRandom rng = new SecureRandom(seed); - System.out.println("preparing the authority block, block0"); + System.out.println("preparing the authority block, block0"); - KeyPair keypair0 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + KeyPair keypair0 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - // org.biscuitsec.biscuit.token.builder.Block block0 = new org.biscuitsec.biscuit.token.builder.Block(0); - org.biscuitsec.biscuit.token.builder.Biscuit block0 = Biscuit.builder(rng, keypair0); - block0.add_authority_fact(Utils.fact("right", List.of(Utils.s("file1"), Utils.s("read")))); - block0.add_authority_fact(Utils.fact("right", List.of(Utils.s("file2"), Utils.s("read")))); - block0.add_authority_fact(Utils.fact("right", List.of(Utils.s("file1"), Utils.s("write")))); + // org.biscuitsec.biscuit.token.builder.Block block0 = new + // org.biscuitsec.biscuit.token.builder.Block(0); + org.biscuitsec.biscuit.token.builder.Biscuit block0 = Biscuit.builder(rng, keypair0); + block0.addAuthorityFact(Utils.fact("right", List.of(Utils.str("file1"), Utils.str("read")))); + block0.addAuthorityFact(Utils.fact("right", List.of(Utils.str("file2"), Utils.str("read")))); + block0.addAuthorityFact(Utils.fact("right", List.of(Utils.str("file1"), Utils.str("write")))); + Biscuit biscuit0 = block0.build(); - Biscuit biscuit0 = block0.build(); + System.out.println(biscuit0.print()); + System.out.println("serializing the first token"); - System.out.println(biscuit0.print()); - System.out.println("serializing the first token"); + String data = biscuit0.serializeBase64Url(); - String data = biscuit0.serialize_b64url(); + System.out.print("data len: "); + System.out.println(data.length()); + System.out.println(data); - System.out.print("data len: "); - System.out.println(data.length()); - System.out.println(data); + System.out.println("deserializing the first token"); + UnverifiedBiscuit deser0 = UnverifiedBiscuit.fromBase64Url(data); + System.out.println(deser0.print()); - System.out.println("deserializing the first token"); - UnverifiedBiscuit deser0 = UnverifiedBiscuit.from_b64url(data); - System.out.println(deser0.print()); + // SECOND BLOCK + System.out.println("preparing the second block"); - // SECOND BLOCK - System.out.println("preparing the second block"); - - KeyPair keypair1 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - org.biscuitsec.biscuit.token.builder.Block block1 = deser0.create_block(); - block1.add_check(Utils.check(Utils.rule( + KeyPair keypair1 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + org.biscuitsec.biscuit.token.builder.Block block1 = deser0.createBlock(); + block1.addCheck( + Utils.check( + Utils.rule( "caveat1", List.of(Utils.var("resource")), List.of( - Utils.pred("resource", List.of(Utils.var("resource"))), - Utils.pred("operation", List.of(Utils.s("read"))), - Utils.pred("right", List.of(Utils.var("resource"), Utils.s("read"))) - ) - ))); - UnverifiedBiscuit unverifiedBiscuit1 = deser0.attenuate(rng, keypair1, block1.build()); + Utils.pred("resource", List.of(Utils.var("resource"))), + Utils.pred("operation", List.of(Utils.str("read"))), + Utils.pred("right", List.of(Utils.var("resource"), Utils.str("read"))))))); + UnverifiedBiscuit unverifiedBiscuit1 = deser0.attenuate(rng, keypair1, block1.build()); - System.out.println(unverifiedBiscuit1.print()); + System.out.println(unverifiedBiscuit1.print()); - System.out.println("serializing the second token"); + System.out.println("serializing the second token"); - String data1 = unverifiedBiscuit1.serialize_b64url(); + String data1 = unverifiedBiscuit1.serializeBase64Url(); - System.out.print("data len: "); - System.out.println(data1.length()); - System.out.println(data1); + System.out.print("data len: "); + System.out.println(data1.length()); + System.out.println(data1); - System.out.println("deserializing the second token"); - UnverifiedBiscuit deser1 = UnverifiedBiscuit.from_b64url(data1); + System.out.println("deserializing the second token"); + UnverifiedBiscuit deser1 = UnverifiedBiscuit.fromBase64Url(data1); - System.out.println(deser1.print()); + System.out.println(deser1.print()); - // THIRD BLOCK - System.out.println("preparing the third block"); + // THIRD BLOCK + System.out.println("preparing the third block"); - KeyPair keypair2 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); + KeyPair keypair2 = KeyPair.generate(Schema.PublicKey.Algorithm.Ed25519, rng); - Block block2 = unverifiedBiscuit1.create_block(); - block2.add_check(Utils.check(Utils.rule( + Block block2 = unverifiedBiscuit1.createBlock(); + block2.addCheck( + Utils.check( + Utils.rule( "caveat2", - List.of(Utils.s("file1")), - List.of( - Utils.pred("resource", List.of(Utils.s("file1"))) - ) - ))); + List.of(Utils.str("file1")), + List.of(Utils.pred("resource", List.of(Utils.str("file1"))))))); - UnverifiedBiscuit unverifiedBiscuit2 = unverifiedBiscuit1.attenuate(rng, keypair2, block2); + UnverifiedBiscuit unverifiedBiscuit2 = unverifiedBiscuit1.attenuate(rng, keypair2, block2); - System.out.println(unverifiedBiscuit2.print()); + System.out.println(unverifiedBiscuit2.print()); - System.out.println("serializing the third token"); + System.out.println("serializing the third token"); - String data2 = unverifiedBiscuit2.serialize_b64url(); + String data2 = unverifiedBiscuit2.serializeBase64Url(); - System.out.print("data len: "); - System.out.println(data2.length()); - System.out.println(data2); + System.out.print("data len: "); + System.out.println(data2.length()); + System.out.println(data2); - System.out.println("deserializing the third token"); - UnverifiedBiscuit finalUnverifiedBiscuit = UnverifiedBiscuit.from_b64url(data2); + System.out.println("deserializing the third token"); + UnverifiedBiscuit finalUnverifiedBiscuit = UnverifiedBiscuit.fromBase64Url(data2); - System.out.println(finalUnverifiedBiscuit.print()); + System.out.println(finalUnverifiedBiscuit.print()); - // Crate Biscuit from UnverifiedBiscuit - Biscuit finalBiscuit = finalUnverifiedBiscuit.verify(keypair0.public_key()); + // Crate Biscuit from UnverifiedBiscuit + Biscuit finalBiscuit = finalUnverifiedBiscuit.verify(keypair0.getPublicKey()); - // check - System.out.println("will check the token for resource=file1 and operation=read"); + // check + System.out.println("will check the token for resource=file1 and operation=read"); - Authorizer authorizer = finalBiscuit.authorizer(); - authorizer.add_fact("resource(\"file1\")"); - authorizer.add_fact("operation(\"read\")"); - authorizer.add_policy("allow if true"); - authorizer.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); + Authorizer authorizer = finalBiscuit.authorizer(); + authorizer.addFact("resource(\"file1\")"); + authorizer.addFact("operation(\"read\")"); + authorizer.addPolicy("allow if true"); + authorizer.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); - System.out.println("will check the token for resource=file2 and operation=write"); + System.out.println("will check the token for resource=file2 and operation=write"); - Authorizer authorizer2 = finalBiscuit.authorizer(); - authorizer2.add_fact("resource(\"file2\")"); - authorizer2.add_fact("operation(\"write\")"); - authorizer2.add_policy("allow if true"); + Authorizer authorizer2 = finalBiscuit.authorizer(); + authorizer2.addFact("resource(\"file2\")"); + authorizer2.addFact("operation(\"write\")"); + authorizer2.addPolicy("allow if true"); - try { - authorizer2.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); - } catch (Error e) { - System.out.println(e); - assertEquals( - new Error.FailedLogic(new LogicError.Unauthorized(new LogicError.MatchedPolicy.Allow(0), Arrays.asList( - new FailedCheck.FailedBlock(1, 0, "check if resource($resource), operation(\"read\"), right($resource, \"read\")"), - new FailedCheck.FailedBlock(2, 0, "check if resource(\"file1\")") - ))), - e); - } + try { + authorizer2.authorize(new RunLimits(500, 100, Duration.ofMillis(500))); + } catch (Error e) { + System.out.println(e); + assertEquals( + new Error.FailedLogic( + new LogicError.Unauthorized( + new LogicError.MatchedPolicy.Allow(0), + Arrays.asList( + new FailedCheck.FailedBlock( + 1, + 0, + "check if resource($resource), " + + "operation(\"read\"), right($resource, \"read\")"), + new FailedCheck.FailedBlock(2, 0, "check if resource(\"file1\")")))), + e); } -} \ No newline at end of file + } +}