Skip to content

Commit 97e03a4

Browse files
jiyoonie9svc-squareup-copybara
authored and committed
record signature listener implementation in misk-jooq
GitOrigin-RevId: 1d72aa13fa35d80d36e4b12b2ef6a1fa4c148c4c
1 parent 24173b7 commit 97e03a4

File tree

6 files changed

+629
-6
lines changed

6 files changed

+629
-6
lines changed

misk-jooq/api/misk-jooq.api

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ public final class misk/jooq/listeners/AvoidUsingSelectStarListener$Companion {
117117
public final fun getSelectStarFromRegex ()Lkotlin/text/Regex;
118118
}
119119

120+
public final class misk/jooq/listeners/DataIntegrityException : org/jooq/exception/DataAccessException {
121+
public fun <init> (Ljava/lang/String;)V
122+
public fun <init> (Ljava/lang/String;Ljava/lang/Exception;)V
123+
public synthetic fun <init> (Ljava/lang/String;Ljava/lang/Exception;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
124+
}
125+
120126
public final class misk/jooq/listeners/JooqSQLLogger : org/jooq/ExecuteListener {
121127
public static final field Companion Lmisk/jooq/listeners/JooqSQLLogger$Companion;
122128
public fun <init> ()V
@@ -156,6 +162,42 @@ public final class misk/jooq/listeners/JooqTimestampRecordListenerOptions {
156162
public fun toString ()Ljava/lang/String;
157163
}
158164

165+
public abstract interface class misk/jooq/listeners/RecordHasher {
166+
public abstract fun computeMac (Ljava/lang/String;[B)[B
167+
public abstract fun verifyMac (Ljava/lang/String;[B[B)V
168+
}
169+
170+
public final class misk/jooq/listeners/RecordSignatureListener : org/jooq/RecordListener {
171+
public static final field Companion Lmisk/jooq/listeners/RecordSignatureListener$Companion;
172+
public fun <init> (Lmisk/jooq/listeners/RecordHasher;Ljava/util/List;)V
173+
public fun insertStart (Lorg/jooq/RecordContext;)V
174+
public fun loadEnd (Lorg/jooq/RecordContext;)V
175+
public fun updateStart (Lorg/jooq/RecordContext;)V
176+
}
177+
178+
public final class misk/jooq/listeners/RecordSignatureListener$Companion {
179+
public final fun getLog ()Lmu/KLogger;
180+
}
181+
182+
public final class misk/jooq/listeners/TableSignatureDetails {
183+
public fun <init> (Ljava/lang/String;Ljava/util/List;Lorg/jooq/TableField;Lorg/jooq/Table;Z)V
184+
public final fun component1 ()Ljava/lang/String;
185+
public final fun component2 ()Ljava/util/List;
186+
public final fun component3 ()Lorg/jooq/TableField;
187+
public final fun component4 ()Lorg/jooq/Table;
188+
public final fun component5 ()Z
189+
public final fun copy (Ljava/lang/String;Ljava/util/List;Lorg/jooq/TableField;Lorg/jooq/Table;Z)Lmisk/jooq/listeners/TableSignatureDetails;
190+
public static synthetic fun copy$default (Lmisk/jooq/listeners/TableSignatureDetails;Ljava/lang/String;Ljava/util/List;Lorg/jooq/TableField;Lorg/jooq/Table;ZILjava/lang/Object;)Lmisk/jooq/listeners/TableSignatureDetails;
191+
public fun equals (Ljava/lang/Object;)Z
192+
public final fun getAllowNullSignatures ()Z
193+
public final fun getColumns ()Ljava/util/List;
194+
public final fun getSignatureKeyName ()Ljava/lang/String;
195+
public final fun getSignatureRecordColumn ()Lorg/jooq/TableField;
196+
public final fun getTable ()Lorg/jooq/Table;
197+
public fun hashCode ()I
198+
public fun toString ()Ljava/lang/String;
199+
}
200+
159201
public class misk/jooq/testgen/DefaultCatalog : org/jooq/impl/CatalogImpl {
160202
public static final field Companion Lmisk/jooq/testgen/DefaultCatalog$Companion;
161203
public fun <init> ()V
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package misk.jooq.listeners
2+
3+
/**
4+
* Interface for computing and verifying MAC signatures for database records.
5+
*/
6+
interface RecordHasher {
7+
/**
8+
* Computes a MAC signature for the given data using the specified key.
*
9+
* @param keyName The name of the key to use for signing
10+
* @param data The data to sign
11+
* @return The MAC signature as a byte array
12+
* @throws IllegalArgumentException if the key name is not found
13+
* @throws RuntimeException if signature computation fails
14+
*/
15+
fun computeMac(keyName: String, data: ByteArray): ByteArray
16+
17+
/**
18+
* Verifies a MAC signature for the given data using the specified key.
*
* There is no return value: the method returns normally when verification succeeds and reports a failed
* verification only by throwing.
*
19+
* @param keyName The name of the key to use for verification
20+
* @param providedMac The MAC to verify
21+
* @param data The data that was signed
22+
* @throws IllegalArgumentException if the key name is not found
23+
* @throws SecurityException if verification fails
24+
*/
25+
fun verifyMac(keyName: String, providedMac: ByteArray, data: ByteArray)
26+
}
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
package misk.jooq.listeners
2+
3+
import misk.jooq.toInstant
4+
import java.nio.ByteBuffer
5+
import java.time.LocalDateTime
6+
import java.time.temporal.ChronoUnit
7+
import org.jooq.Record
8+
import org.jooq.RecordContext
9+
import org.jooq.Table
10+
import org.jooq.TableField
11+
import org.jooq.exception.DataAccessException
12+
import misk.logging.getLogger
13+
import org.jooq.RecordListener
14+
15+
/**
 * A jOOQ [RecordListener] that guards rows against changes made directly in the database.
 *
 * On insert and update the listener computes a MAC signature from the configured column values (see
 * [TableSignatureDetails.columns]) and stores it in [TableSignatureDetails.signatureRecordColumn]. When the record is
 * loaded again, the MAC is verified against the loaded column data. If the column data has been changed the MAC will
 * not validate and a [DataIntegrityException] is thrown.
 *
 * Things to remember:
 * - Add this in when your service is slightly mature. You can't arbitrarily change the columns used, nor change the
 *   column order used to create the MAC, later on: doing so prevents all old rows from being read.
 * - You can put this listener behind a flag, correct the signatures using a backfill, and then enable the listener
 *   again.
 * - General jOOQ [RecordListener] rules apply. Specifically, a RecordListener does not affect any bulk DML statements
 *   (e.g. a `DSLContext.update(Table)`), whose affected records are not available to clients. More info in
 *   [org.jooq.RecordListener].
 *
 * Important! If you also use the [JooqTimestampRecordListener] and the created_at / updated_at columns are part of
 * the signature, the timestamps must be set before the signature is computed. Register this listener after the
 * [JooqTimestampRecordListener] so that it runs second.
 *
 * To see a full example, see [RecordSignatureListenerTest].
 */
class RecordSignatureListener(
  private val recordHasher: RecordHasher,
  private val tableSignatureDetails: List<TableSignatureDetails>,
) : RecordListener {

  /**
   * We are overriding insertStart and updateStart instead of storeStart() which is called regardless of whether it is
   * an update or an insert. The reason is, people generally add another listener - JooqTimestampRecordListener which
   * sets the timestamp for created_at and updated_at columns. If these columns form part of the signature then these
   * values need to be set before the signature can be calculated. The JooqTimestampRecordListener has insertStart and
   * updateStart overridden. Jooq looks to be calling storeStart() on all listeners before moving on to call
   * insertStart() and updateStart() on all listeners. So either all listeners need to implement storeStart() or
   * insertStart() and updateStart(). We can't mix the 2 styles and guarantee an order.
   */
  override fun insertStart(ctx: RecordContext?) = updateSignature(ctx)

  /** See [insertStart]: updates are signed the same way as inserts. */
  override fun updateStart(ctx: RecordContext?) = updateSignature(ctx)

  /**
   * Computes the MAC over the signed columns of the record in [ctx] and writes it into the signature column.
   *
   * Does nothing when there is no record, or when the record contains none of the configured signature columns.
   */
  private fun updateSignature(ctx: RecordContext?) {
    val record = ctx?.record() ?: return
    val tableSignature = findSignatureDetails(record) ?: return

    val signedBytes = concatenateByteArrayFromColumnValues(tableSignature, record)
    val signature = recordHasher.computeMac(tableSignature.signatureKeyName, signedBytes)
    record.set(tableSignature.signatureRecordColumn, signature)
  }

  /**
   * Verifies the stored signature of a freshly loaded record.
   *
   * @throws DataIntegrityException if the signature is null while [TableSignatureDetails.allowNullSignatures] is
   *   false, or if [RecordHasher.verifyMac] throws for the loaded column data
   */
  override fun loadEnd(ctx: RecordContext?) {
    val record = ctx?.record() ?: return
    val tableSignature = findSignatureDetails(record) ?: return

    // Skip validation if all signature columns are null (indicates a partially loaded record).
    // NOTE(review): this also means a row whose signed columns were all set to NULL directly in the database is
    // never verified - confirm that this is acceptable.
    if (tableSignature.columns.all { column -> record.get(column) == null }) return

    val signature = record.get(tableSignature.signatureRecordColumn)
    if (signature == null) {
      if (tableSignature.allowNullSignatures) return
      throw DataIntegrityException(exceptionMessage("Signature is null", tableSignature, record))
    }

    // Built outside the try block so that a failure while encoding is not reported as a signature mismatch.
    val signedBytes = concatenateByteArrayFromColumnValues(tableSignature, record)

    try {
      recordHasher.verifyMac(tableSignature.signatureKeyName, signature, signedBytes)
    } catch (e: Exception) {
      // Any failure raised by the hasher is treated as a verification failure and surfaced as a data integrity error.
      val message =
        exceptionMessage("The data in the database does not match the record signature on the", tableSignature, record)
      log.warn(e) { message }
      throw DataIntegrityException(message, cause = e)
    }
  }

  /** Returns the first configured [TableSignatureDetails] whose signature column is present in [record], if any. */
  private fun findSignatureDetails(record: Record): TableSignatureDetails? =
    tableSignatureDetails.find { record.field(it.signatureRecordColumn) != null }

  /**
   * Builds the byte array that gets signed, by encoding every signed column of [record] (in the configured order)
   * with an LV (length-value) scheme and concatenating the results.
   *
   * Encoding the column values in this manner prevents two distinct records from creating the same signature.
   *
   * Encoding scheme:
   * - null: 4 bytes with value -1 (no data follows)
   * - non-null: 4 bytes with length >= 0, followed by that many bytes of data
   *
   * More info here: https://en.wikipedia.org/wiki/Type%E2%80%93length%E2%80%93value
   *
   * Without LV encoding:
   * ```
   * id | foo | bar |
   * 1  | ab  | c   | bytearray(ab) + bytearray(c)
   * 2  | a   | bc  | bytearray(a) + bytearray(bc)
   * ```
   * result: the two byte arrays from record 1 and record 2 are the same.
   *
   * With LV encoding:
   * ```
   * id | foo | bar |
   * 1  | ab  | c   | (length(2) + bytearray(ab)) + (length(1) + bytearray(c))
   * 2  | a   | bc  | (length(1) + bytearray(a)) + (length(2) + bytearray(bc))
   * ```
   * result: the two byte arrays from record 1 and record 2 are NOT the same.
   *
   * Null values are encoded with a special marker (-1) to prevent collisions like:
   * ```
   * id | foo  | bar  |
   * 1  | null | a    | marker(null) + (length(1) + bytearray(a))
   * 2  | a    | null | (length(1) + bytearray(a)) + marker(null)
   * ```
   * The null marker cannot collide with a real value because a real length is never negative.
   */
  private fun concatenateByteArrayFromColumnValues(
    tableSignature: TableSignatureDetails,
    record: Record,
  ): ByteArray {
    val encodedColumns = tableSignature.columns.map { column -> encodeColumnValue(column, record.get(column)) }

    // Allocate the result once instead of re-copying the accumulated bytes for every column.
    val buffer = ByteBuffer.allocate(encodedColumns.sumOf { it.size })
    encodedColumns.forEach { buffer.put(it) }
    return buffer.array()
  }

  /** Encodes a single column value as described on [concatenateByteArrayFromColumnValues]. */
  private fun encodeColumnValue(column: TableField<out Record, out Any?>, columnValue: Any?): ByteArray =
    when (columnValue) {
      // For null values, encode only the -1 marker (no value bytes follow)
      null -> ByteBuffer.allocate(Int.SIZE_BYTES).putInt(NULL_LENGTH_MARKER).array()

      // ByteArray values are used as they are
      is ByteArray -> lengthPrefixed(columnValue)

      // LocalDateTime values are truncated according to the column precision before being encoded
      is LocalDateTime -> lengthPrefixed(columnValue.toByteArray(column.dataType.precision()))

      // All other types are converted to a string, then to (UTF-8) bytes
      else -> lengthPrefixed(columnValue.toString().toByteArray(Charsets.UTF_8))
    }

  /** Returns [valueBytes] prefixed with its length as a 4 byte integer. */
  private fun lengthPrefixed(valueBytes: ByteArray): ByteArray =
    ByteBuffer.allocate(Int.SIZE_BYTES + valueBytes.size).putInt(valueBytes.size).put(valueBytes).array()

  /**
   * MySQL's precision for a timestamp is millis. But in the Kube Pod, where the code runs the JVM timestamp is in
   * nanos. So when we store the data, the signature is computed with nanos, but when we load the data from the DB, the
   * nanos are lost and hence the signature computed is different. This method truncates the instant based on the
   * precision. The check with precision is required to be able to test this on a Mac. The Mac JVM's precision is
   * millis. So in order to test truncation we need to create a MySQL timestamp with a precision of 0. This also allows
   * this signature to work for any column created in prod where the precision is 0 (in the sense, restricted to store
   * seconds alone).
   *
   * @param precision the precision of the column's data type; values below 3 truncate to whole seconds, anything else
   *   truncates to milliseconds
   * @return the truncated epoch-millis value, rendered as a decimal string and converted to bytes
   */
  private fun LocalDateTime.toByteArray(precision: Int): ByteArray {
    val truncationUnit = if (precision < 3) ChronoUnit.SECONDS else ChronoUnit.MILLIS
    return toInstant().truncatedTo(truncationUnit).toEpochMilli().toString().toByteArray()
  }

  /** Appends the table and the primary key value(s) of [record] to [message]. */
  private fun exceptionMessage(message: String, tableSignature: TableSignatureDetails, record: Record): String {
    val primaryKeyValues = tableSignature.table.primaryKey?.fields?.map { record.get(it) }?.joinToString(", ")
    return "$message [Table=${tableSignature.table}] [PK=$primaryKeyValues]"
  }

  companion object {
    val log = getLogger<RecordSignatureListener>()

    /** Length value written for a null column; real lengths are never negative. */
    private const val NULL_LENGTH_MARKER = -1
  }
}
187+
188+
/**
 * Describes how the records of one table are signed: which key is used, which columns are covered by the signature,
 * and in which column the signature is stored. Consumed by [RecordSignatureListener].
 */
data class TableSignatureDetails(
189+
/**
190+
* The key name used to create the HMAC signature. More details here:
191+
* https://cash-dev-guide.squarecloudservices.com/security/key_management/ and
192+
* https://github.com/google/tink/blob/master/docs/PRIMITIVES.md
193+
*/
194+
val signatureKeyName: String,
195+
/**
196+
* The columns that need to be protected against direct change in the database. Please note: the value of these
197+
* columns should be convertible deterministically into a string value or should be a byte array already such as BLOB
198+
* types. Most SQL value types can be converted into a string via toString() call. The order of this list matters:
* the column values are concatenated in list order when the signature is computed. Note:
199+
* 1. JSON columns cannot be used as part of the signature columns as the string comparison of a JSON differs if there
200+
* are whitespace differences. MySQL does not store JSON as a string and hence when it is retrieved there usually
201+
* are whitespace differences.
202+
* 2. If a timestamp column is used in the signature, remember that MySQL's precision is limited to millis. The JVM
203+
* precision on a Mac is limited to millis too. But the Kube pod where this is deployed has nano precision. So
204+
* ensure the timestamp is truncated to millis before setting it into the record.
205+
*/
206+
val columns: List<TableField<out Record, out Any?>>,
207+
/** The column where the HMAC signature (or hash) will be stored and then used to validate against */
208+
val signatureRecordColumn: TableField<out Record, ByteArray?>,
209+
/** The table that needs to be protected against direct change in the database. */
210+
val table: Table<out Record>,
211+
/**
212+
* When adding this listener to an existing table, set this flag to true until you are sure that all records in the
213+
* table have a signature set. While it is false, loading a record whose signature column is null fails with a
* [DataIntegrityException].
214+
*/
215+
val allowNullSignatures: Boolean,
216+
)
217+
218+
/**
 * Raised by [RecordSignatureListener] when a loaded record has no signature although one is required, or when the
 * stored signature does not verify against the loaded column data.
 *
 * @param message description of the failure, including the affected table and primary key
 * @param cause the exception raised by the signature verification, if any
 */
class DataIntegrityException
@JvmOverloads
constructor(
  message: String,
  cause: Exception? = null,
) : DataAccessException(message, cause)

misk-jooq/src/test/kotlin/misk/jooq/config/ClientJooqTestingModule.kt

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,13 @@ import misk.jooq.JooqModule
1111
import misk.jooq.listeners.JooqTimestampRecordListenerOptions
1212
import misk.logging.LogCollectorModule
1313
import org.jooq.impl.DefaultExecuteListenerProvider
14+
import org.jooq.impl.DefaultRecordListenerProvider
1415
import wisp.deployment.TESTING
1516
import jakarta.inject.Qualifier
17+
import misk.jooq.listeners.RecordSignatureListener
18+
import misk.jooq.listeners.TableSignatureDetails
19+
import misk.jooq.testgen.tables.references.RECORD_SIGNATURE_TEST
20+
import org.jooq.RecordListenerProvider
1621

1722
class ClientJooqTestingModule : KAbstractModule() {
1823
override fun configure() {
@@ -46,12 +51,45 @@ class ClientJooqTestingModule : KAbstractModule() {
4651
createdAtColumnName = "created_at",
4752
updatedAtColumnName = "updated_at"
4853
),
49-
readerQualifier = JooqDBReadOnlyIdentifier::class
50-
) {
51-
val executeListeners = this.executeListenerProviders().toMutableList()
52-
.apply { add(DefaultExecuteListenerProvider(DeleteOrUpdateWithoutWhereListener())) }
53-
set(*executeListeners.toTypedArray())
54-
})
54+
readerQualifier = JooqDBReadOnlyIdentifier::class,
55+
jooqConfigExtension = {
56+
// Since JooqTimestampRecordListener might overwrite record listeners,
57+
// we need to ensure both timestamp and signature listeners are present
58+
val recordListeners = mutableListOf<RecordListenerProvider>()
59+
60+
// Add any existing record listeners (like JooqTimestampRecordListener)
61+
recordListeners.addAll(this.recordListenerProviders())
62+
63+
// Add our RecordSignatureListener
64+
recordListeners.add(
65+
DefaultRecordListenerProvider(
66+
RecordSignatureListener(
67+
recordHasher = FakeRecordHasher(),
68+
tableSignatureDetails = listOf(
69+
TableSignatureDetails(
70+
signatureKeyName = "signature-record-test",
71+
table = RECORD_SIGNATURE_TEST,
72+
columns = listOf(
73+
RECORD_SIGNATURE_TEST.NAME,
74+
RECORD_SIGNATURE_TEST.UPDATED_BY,
75+
RECORD_SIGNATURE_TEST.BINARY_DATA,
76+
),
77+
signatureRecordColumn = RECORD_SIGNATURE_TEST.RECORD_SIGNATURE,
78+
allowNullSignatures = false
79+
)
80+
),
81+
)
82+
)
83+
)
84+
85+
set(*recordListeners.toTypedArray())
86+
87+
// Add execute listeners separately
88+
val existingExecuteListeners = this.executeListenerProviders().toMutableList()
89+
existingExecuteListeners.add(DefaultExecuteListenerProvider(DeleteOrUpdateWithoutWhereListener()))
90+
set(*existingExecuteListeners.toTypedArray())
91+
}
92+
))
5593
install(JdbcTestingModule(JooqDBIdentifier::class))
5694
install(LogCollectorModule())
5795
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package misk.jooq.config
2+
3+
import misk.jooq.listeners.RecordHasher
4+
5+
/**
 * Fake [RecordHasher] for tests. The "MAC" is just the key name joined with the content hash code of the data, so
 * it is deterministic but not secure.
 */
class FakeRecordHasher : RecordHasher {

  /** Returns the bytes of `"<keyName>:<contentHashCode of data>"`. */
  override fun computeMac(keyName: String, data: ByteArray): ByteArray =
    "$keyName:${data.contentHashCode()}".toByteArray()

  /**
   * Recomputes the fake MAC for [data] and compares it with [providedMac].
   *
   * @throws SecurityException if the two differ
   */
  override fun verifyMac(keyName: String, providedMac: ByteArray, data: ByteArray) {
    val recomputedMac = computeMac(keyName, data)
    if (providedMac.contentEquals(recomputedMac)) return
    throw SecurityException("MAC verification failed")
  }
}

0 commit comments

Comments
 (0)