Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import dev.slimevr.serial.SerialPort
import dev.slimevr.tracking.trackers.Tracker
import dev.slimevr.tracking.trackers.TrackerStatus
import dev.slimevr.tracking.trackers.TrackerStatusListener
import dev.slimevr.tracking.trackers.udp.MCUType
import dev.slimevr.tracking.trackers.udp.UDPDevice
import io.eiren.util.logging.LogManager
import kotlinx.coroutines.*
Expand Down Expand Up @@ -101,11 +102,27 @@ class FirmwareUpdateHandler(private val server: VRServer) :
)
return@suspendCancellableCoroutine
}

// TODO:
// - Use the Firmware Builder to get the expected MCU
// It would be wrong to assume that the Target MCU is the correct one,
// just because the device is listening on the correct port.
// The Upload protocol does not verify the compatibility of the firmware with the MCU.

val port = when (udpDevice.mcuType) {
MCUType.ESP8266 -> 8266
MCUType.ESP32, MCUType.ESP32_C3 -> 3232
else -> error("MCU-Typ: ${udpDevice.mcuType} not supported for OTA updates")
}

LogManager.info("[FirmwareUpdateHandler] Starting OTA update for device ${deviceId.id} at ${udpDevice.ipAddress.hostAddress}:$port and MCU ${udpDevice.mcuType}")

val task = OTAUpdateTask(
part.firmware,
deviceId,
udpDevice.ipAddress,
::onStatusChange,
port,
)
c.invokeOnCancellation {
task.cancel()
Expand Down Expand Up @@ -156,6 +173,15 @@ class FirmwareUpdateHandler(private val server: VRServer) :
flasher.addBin(part.firmware, part.offset.toInt())
}

// TODO:
// - Check if FW is able to use flashmode
// - Add check if the flashmode was successfully set to surpress the request
// for manual flashmode setting prompt in gui

server.serialHandler.openSerial(deviceId.id, false)
server.serialHandler.write("SET FLASHMODE\r\n".toByteArray())
server.serialHandler.closeSerial()

flasher.addProgressListener(object : FlashingProgressListener {
override fun progress(progress: Float) {
onStatusChange(
Expand Down
53 changes: 43 additions & 10 deletions server/core/src/main/java/dev/slimevr/firmware/OTAUpdateTask.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,18 @@ import java.security.MessageDigest
import java.security.NoSuchAlgorithmException
import java.util.*
import java.util.function.Consumer
import javax.crypto.SecretKeyFactory
import javax.crypto.spec.PBEKeySpec
import kotlin.math.min

class OTAUpdateTask(
private val firmware: ByteArray,
private val deviceId: UpdateDeviceId<Int>,
private val deviceIp: InetAddress,
private val statusCallback: Consumer<UpdateStatusEvent<Int>>,
private val port: Int = 8266,
) {
private val receiveBuffer: ByteArray = ByteArray(38)
private val receiveBuffer: ByteArray = ByteArray(69)
var socketServer: ServerSocket? = null
var uploadSocket: Socket? = null
var authSocket: DatagramSocket? = null
Expand All @@ -40,17 +43,35 @@ class OTAUpdateTask(
return md5str.toString()
}

@Throws(NoSuchAlgorithmException::class)
private fun bytesToSha256(bytes: ByteArray): String {
val sha256 = MessageDigest.getInstance("SHA-256")
sha256.update(bytes)
val digest = sha256.digest()
val sha256str = StringBuilder()
for (b in digest) {
sha256str.append(String.format("%02x", b))
}
return sha256str.toString()
}

fun pbkdf2Hmac(password: String, salt: ByteArray, iterations: Int, keyLength: Int): ByteArray {
val spec = PBEKeySpec(password.toCharArray(), salt, iterations, keyLength * 8)
val factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256")
return factory.generateSecret(spec).encoded
}

private fun authenticate(localPort: Int): Boolean {
try {
DatagramSocket().use { socket ->
authSocket = socket
statusCallback.accept(UpdateStatusEvent(deviceId, FirmwareUpdateStatus.AUTHENTICATING))
LogManager.info("[OTAUpdate] Sending OTA invitation to: $deviceIp")
LogManager.info("[OTAUpdate] Sending OTA invitation to: $deviceIp:$port")

val fileMd5 = bytesToMd5(firmware)
val message = "$FLASH $localPort ${firmware.size} $fileMd5\n"

socket.send(DatagramPacket(message.toByteArray(), message.length, deviceIp, PORT))
socket.send(DatagramPacket(message.toByteArray(), message.length, deviceIp, port))
socket.soTimeout = 10000

val authPacket = DatagramPacket(receiveBuffer, receiveBuffer.size)
Expand All @@ -68,13 +89,26 @@ class OTAUpdateTask(
if (args.size != 2 || args[0] != "AUTH") return false

LogManager.info("[OTAUpdate] Authenticating...")
var payload = ""
var signature = ""

val authToken = args[1]
val signature = bytesToMd5(UUID.randomUUID().toString().toByteArray())
val hashedPassword = bytesToMd5(PASSWORD.toByteArray())
val resultText = "$hashedPassword:$authToken:$signature"
val payload = bytesToMd5(resultText.toByteArray())

if (authToken.length == 32) {
signature =
bytesToMd5(UUID.randomUUID().toString().toByteArray())
val hashedPassword = bytesToMd5(PASSWORD.toByteArray())
val resultText = "$hashedPassword:$authToken:$signature"
payload = bytesToMd5(resultText.toByteArray())
} else if (authToken.length == 64) {
signature =
bytesToSha256(UUID.randomUUID().toString().toByteArray())
val salt = "$authToken:$signature"
val hashedPassword = bytesToSha256(PASSWORD.toByteArray())
val derivedkey = pbkdf2Hmac(hashedPassword, salt.toByteArray(), 10000, 32)
val derivedkeyHex = derivedkey.joinToString("") { "%02x".format(it) }
val challenge = "$derivedkeyHex:$authToken:$signature"
payload = bytesToSha256(challenge.toByteArray())
}
val authMessage = "$AUTH $signature $payload\n"

socket.soTimeout = 10000
Expand All @@ -83,7 +117,7 @@ class OTAUpdateTask(
authMessage.toByteArray(),
authMessage.length,
deviceIp,
PORT,
port,
),
)

Expand Down Expand Up @@ -198,7 +232,6 @@ class OTAUpdateTask(

companion object {
private const val FLASH = 0
private const val PORT = 8266
private const val PASSWORD = "SlimeVR-OTA"
private const val AUTH = 200
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class DesktopSerialHandler :
}

override fun write(buff: ByteArray) {
LogManager.info("[SerialHandler] WRITING $buff")
LogManager.info("[SerialHandler] WRITING ${buff.toString(Charsets.UTF_8)}")
currentPort?.outputStream?.write(buff)
}

Expand Down