diff --git a/ishlib/formats/fastxpp_bench.mojo b/ishlib/formats/fastxpp_bench.mojo
new file mode 100644
index 0000000..3899caf
--- /dev/null
+++ b/ishlib/formats/fastxpp_bench.mojo
@@ -0,0 +1,116 @@
+import sys
+from time.time import perf_counter
+from ishlib.vendor.kseq import FastxReader, BufferedReader
+from ishlib.vendor.zlib import GZFile
+from ExtraMojo.utils.ir import dump_ir
+
+
+fn bench_original(path: String) raises -> (Int, Int, Float64):
+    """Time the baseline `read()` parser; returns (records, bases, seconds)."""
+    var rdr = FastxReader[read_comment=False](BufferedReader(GZFile(path, "r")))
+    var rec = 0
+    var seq = 0
+    var t0 = perf_counter()
+    while rdr.read() > 0:
+        rec += 1
+        seq += len(rdr.seq)
+    return (rec, seq, perf_counter() - t0)
+
+
+fn bench_fastxpp_strip_newline(path: String) raises -> (Int, Int, Float64):
+    """Time `read_fastxpp_strip_newline`; returns (records, bases, seconds)."""
+    var rdr = FastxReader[read_comment=False](BufferedReader(GZFile(path, "r")))
+    var rec = 0
+    var seq = 0
+    var t0 = perf_counter()
+    while True:
+        var n = rdr.read_fastxpp_strip_newline()
+        if n < 0:
+            break
+        rec += 1
+        seq += n
+    return (rec, seq, perf_counter() - t0)
+
+
+fn bench_fastxpp_swar(path: String) raises -> (Int, Int, Float64):
+    """Time `read_fastxpp_swar`; returns (records, bases, seconds)."""
+    var rdr = FastxReader[read_comment=False](BufferedReader(GZFile(path, "r")))
+    var rec = 0
+    var seq = 0
+    var t0 = perf_counter()
+    while True:
+        var n = rdr.read_fastxpp_swar()
+        if n < 0:
+            break
+        rec += 1
+        seq += n
+    return (rec, seq, perf_counter() - t0)
+
+
+fn bench_fastxpp_read_once(path: String) raises -> (Int, Int, Float64):
+    """Time `read_fastxpp_read_once`; returns (records, bases, seconds)."""
+    var rdr = FastxReader[read_comment=False](BufferedReader(GZFile(path, "r")))
+    var rec = 0
+    var seq = 0
+    var t0 = perf_counter()
+    while True:
+        var n = rdr.read_fastxpp_read_once()
+        if n < 0:
+            break
+        rec += 1
+        seq += n
+    return (rec, seq, perf_counter() - t0)
+
+
+fn main() raises:
+    """Run one benchmark over argv[1]; argv[2] picks the mode (default `orig`)."""
+    var argv = sys.argv()
+    if len(argv) < 2 or len(argv) > 3:
+        print(
+            "Usage: mojo run fastxpp_bench.mojo <file>"
+            " [orig|strip_newline|swar|read_once]"
+        )
+        return
+
+    var path = String(argv[1])
+    var mode: String = "orig"  # default when no flag given
+    if len(argv) == 3:
+        mode = String(argv[2])
+
+    if mode == "orig":
+        r, s, t = bench_original(path)
+        print(
+            "mode=orig records=", r, " bases=", s, " time=", t, "s"
+        )
+    elif mode == "strip_newline":
+        r, s, t = bench_fastxpp_strip_newline(path)
+        print(
+            "mode=read_fastxpp_strip_newline records=",
+            r,
+            " bases=",
+            s,
+            " time=",
+            t,
+            "s",
+        )
+    elif mode == "swar":
+        r, s, t = bench_fastxpp_swar(path)
+        print(
+            "mode=fastxpp_swar records=", r, " bases=", s, " time=", t, "s"
+        )
+    elif mode == "read_once":
+        r, s, t = bench_fastxpp_read_once(path)
+        print("mode=read_once records=", r, " bases=", s, " time=", t, "s")
+    elif mode == "filler":  # same reader as read_once, different label
+        r, s, t = bench_fastxpp_read_once(path)
+        print(
+            "mode=bench_fastxpp_read_once records=",
+            r,
+            " bases=",
+            s,
+            " time=",
+            t,
+            "s",
+        )
+    else:
+        print("Unknown mode:", mode)
diff --git a/ishlib/formats/generate_fastxpp.mojo b/ishlib/formats/generate_fastxpp.mojo
new file mode 100644
index 0000000..52458af
--- /dev/null
+++ b/ishlib/formats/generate_fastxpp.mojo
@@ -0,0 +1,219 @@
+import sys
+from collections import Optional
+from ExtraMojo.io.buffered import BufferedReader, BufferedWriter
+from collections import List  # dynamic grow-able buffer
+from memory import Span  # view into the List for zero-copy writes
+
+# ---------- helpers -------------------------------------------------
+
+
+fn string_count(s: String) -> Int:
+    """Return the number of codepoints in `s`."""
+    var n: Int = 0
+    for _ in s.codepoints():
+        n = n + 1
+    return n
+
+
+fn read_line(mut rdr: BufferedReader) raises -> String:
+    """Read up to the next LF; returns "" when `read_until` reports 0."""
+    var buf = List[UInt8]()
+    var n = rdr.read_until(buf, ord("\n"))
+    if n == 0:
+        return ""
+    var s = String()
+    s.write_bytes(Span(buf))
+    return s
+
+
+# ---------- FASTX++ builder -----------------------------------------
+
+
+fn generate_fastxpp(
+    marker: String,
+    header: String,
+    seq_lines: List[String],
+    qualities: Optional[List[String]] = None,
+) -> String:
+    """Build a record with a variable-width `hlen:slen:nlin` metadata block."""
+    var seq_len: Int = 0
+    for i in range(len(seq_lines)):
+        seq_len = seq_len + string_count(seq_lines[i])
+
+    var meta = String(string_count(header)) + ":" + String(
+        seq_len
+    ) + ":" + String(len(seq_lines))
+
+    var rec = marker + "`" + meta + "`" + header + "\n"
+
+    for i in range(len(seq_lines)):
+        rec.write(seq_lines[i], "\n")
+
+    if qualities:
+        var q = qualities.value()
+        rec += "+\n"
+        for i in range(len(q)):
+            rec.write(q[i], "\n")
+
+    return rec
+
+
+fn generate_fastxpp_bpl(
+    marker: String,
+    header: String,
+    seq_lines: List[String],
+    qualities: Optional[List[String]] = None,
+) -> String:
+    """Build a record with an `hlen:slen:nlin:bpl` metadata block.
+
+    `bpl` is the width of the first sequence line plus its LF; every line
+    but the last is assumed to share that width. Quality lines, if any, are
+    written directly after the sequence with no "+" separator line.
+    """
+    var nlin = len(seq_lines)
+    var bpl = 0
+    var slen = 0
+    if nlin > 0:  # an empty record has no first/last line to index
+        bpl = string_count(seq_lines[0]) + 1  # bases + LF
+        slen = (
+            (bpl - 1) * (nlin - 1)  # bases in the full-width lines
+            + string_count(seq_lines[-1])  # + last (ragged) line
+        )
+    var meta = (
+        String(string_count(header))
+        + ":"
+        + String(slen)
+        + ":"
+        + String(nlin)
+        + ":"
+        + String(bpl)
+    )
+    var rec = marker + "`" + meta + "`" + header + "\n"
+    for i in range(nlin):
+        rec.write(seq_lines[i], "\n")
+    if qualities:
+        var q = qualities.value()
+        for i in range(len(q)):
+            rec.write(q[i], "\n")
+    return rec
+
+
+# Helper: encode an unsigned value as zero-padded ASCII.
+fn to_ascii_padded(value: Int, width: Int) -> String:
+    """Return `value` in decimal, left-padded with "0" up to `width` characters.
+
+    A value with more than `width` digits gets no padding, so the result is
+    longer than `width` in that case.
+    """
+    var digits = String(value)  # e.g. "123"
+    var pad = width - string_count(digits)  # how many zeros needed
+
+    var out = String(capacity=width)
+    for _ in range(pad):
+        out.write("0")
+    out.write(digits)
+    return out
+
+
+fn generate_fastxpp_bpl_fixed(
+    marker: String,
+    header: String,
+    seq_lines: List[String],
+    qualities: Optional[List[String]] = None,
+) -> String:
+    """Build a record whose metadata block has a fixed width.
+
+    Layout: marker, "`", slen (9 digits), nlin (7 digits), bpl (3 digits),
+    "`", header, LF, the sequence lines, then any quality lines.
+    """
+    # --- numeric fields ------------------------------------------------
+    var nlin = len(seq_lines)
+    var bpl = 0
+    var slen = 0
+    if nlin > 0:  # an empty record has no first/last line to index
+        bpl = string_count(seq_lines[0]) + 1  # incl. LF
+        slen = (bpl - 1) * (nlin - 1) + string_count(seq_lines[-1])
+
+    # --- fixed-width metadata block ------------------------------------
+    # Disabled hlen field: to_ascii_padded(string_count(header), 6)
+    var meta = (
+        "`"
+        + to_ascii_padded(slen, 9)  # slen
+        + to_ascii_padded(nlin, 7)  # nlin
+        + to_ascii_padded(bpl, 3)  # bpl
+        + "`"
+    )
+
+    # --- assemble record -----------------------------------------------
+    var rec = marker + meta + header + "\n"
+    for i in range(nlin):
+        rec.write(seq_lines[i], "\n")
+    if qualities:
+        var q = qualities.value()
+        for i in range(len(q)):
+            rec.write(q[i], "\n")
+    return rec
+
+
+# ---------- main ----------------------------------------------------
+
+
+fn main() raises:
+    """Rewrite the FASTA/FASTQ file argv[1] as fixed-width FASTX++ in argv[2]."""
+    var argv = sys.argv()
+    if len(argv) != 3:
+        print(
+            "Usage: mojo run generate_fastxpp.mojo <input.fastx>"
+            " <output.fastxpp>"
+        )
+        return
+
+    var reader = BufferedReader(
+        open(String(argv[1]), "r"), buffer_capacity=128 * 1024
+    )
+    var writer = BufferedWriter(
+        open(String(argv[2]), "w"), buffer_capacity=128 * 1024
+    )
+
+    var pending_header = String()  # carries a header we already read
+
+    while True:
+        var header_line = pending_header
+        if header_line == "":
+            header_line = read_line(reader)
+        pending_header = String()
+
+        if header_line == "":
+            break
+
+        var marker = String(header_line[0:1])
+        var header = String(header_line[1:])
+
+        var seq = List[String]()
+        var saw_separator = False  # True once the FASTQ "+" line is consumed
+
+        while True:
+            var line = read_line(reader)
+            if line == "":
+                break
+            if marker == "@" and line.startswith("+"):
+                # The separator belongs to this record; it must not be
+                # carried over as the next record's header.
+                saw_separator = True
+                break
+            if line.startswith(">") or line.startswith("@"):
+                pending_header = line  # save for the next record
+                break
+            seq.append(line)
+
+        var qual: Optional[List[String]] = None
+        if saw_separator:
+            var qlines = List[String]()
+            for _ in range(len(seq)):
+                qlines.append(read_line(reader))
+            qual = Optional[List[String]](qlines)
+
+        writer.write(generate_fastxpp_bpl_fixed(marker, header, seq, qual))
+
+    writer.flush()
+    writer.close()
diff --git a/ishlib/vendor/kseq.mojo b/ishlib/vendor/kseq.mojo
index 2af6dcb..5116eae 100644
--- a/ishlib/vendor/kseq.mojo
+++ b/ishlib/vendor/kseq.mojo
@@ -20,14 +20,15 @@ def main():
     print(count, slen, qlen, sep="\t")
 ```
 """
+import sys
 from memory import UnsafePointer, memcpy
 from utils import StringSlice
+from ishlib.vendor.swar_decode import decode
 
 from time.time import perf_counter
 from ExtraMojo.bstr.memchr import memchr
 
 
-
 alias ASCII_NEWLINE = ord("\n")
 alias ASCII_CARRIAGE_RETURN = ord("\r")
 alias ASCII_TAB = ord("\t")
@@ -35,6 +36,60 @@ alias ASCII_SPACE = ord(" ")
 alias ASCII_FASTA_RECORD_START = ord(">")
 alias ASCII_FASTQ_RECORD_START = ord("@")
 alias ASCII_FASTQ_SEPARATOR = ord("+")
+alias ASCII_ZERO = UInt8(ord("0"))
+
+
+# ──────────────────────────────────────────────────────────────
+# Helpers for reading in fastx++ files
+# ──────────────────────────────────────────────────────────────
+
+
+@always_inline
+fn strip_newlines_in_place(
+    mut bs: ByteString, disk: Int, expected: Int
+) -> Bool:
+    """Compact `bs` by removing every `\n` byte in‑place; return True if the
+    resulting length equals `expected`.
+    SIMD search for newline, shifts the bytes to the left, and resizes the buffer.
+    Avoids allocating a new buffer and copying the data.
+
+    bs: Mutable buffer that already holds the raw FASTQ/FASTA chunk just read from disk
+    disk: The number of bytes that were actually read into bs
+    expected: how many bases/quality bytes should remain after stripping newlines;
+        used as a quick integrity check.
+
+    Returns:
+        True if the resulting buffer's length equals `expected`, False otherwise.
+    """
+    # read_pos always starts the loop at the first byte that has not yet been examined.
+    var read_pos: Int = 0
+    # write_pos always starts at the first byte that has not yet been written into its final position
+    var write_pos: Int = 0
+    # Before the first newline, every byte is kept, so the pointers march together (no gap)
+    # After the first newline, the pointers may diverge, and we will need to copy bytes
+
+    while read_pos < disk:
+        var span_rel = memchr[do_alignment=False](
+            Span[UInt8, __origin_of(bs.ptr)](
+                ptr=bs.ptr.offset(read_pos), length=disk - read_pos
+            ),
+            UInt8(ASCII_NEWLINE),
+        )
+        # If there is no further newline, the span runs to the end of the data.
+        # Otherwise it is the contiguous run of bytes before that newline.
+        var end_pos = disk if span_rel == -1 else read_pos + span_rel
+        var span_len = end_pos - read_pos
+        # Copy only once a newline has opened a gap (write_pos != read_pos).
+        # NOTE(review): source and destination can overlap here; confirm memcpy tolerates that.
+        if span_len > 0 and write_pos != read_pos:
+            memcpy(bs.ptr.offset(write_pos), bs.ptr.offset(read_pos), span_len)
+        write_pos += span_len
+        read_pos = end_pos + 1  # skip the '\n' (or exit loop if none)
+    bs.resize(write_pos)
+    return write_pos == expected
+
+
+# ──────────────────────────────────────────────────────────────
 
 
 @value
@@ -343,6 +398,7 @@ struct FastxReader[R: KRead, read_comment: Bool = True](Movable):
         self.seq = ByteString(256)
         self.qual = ByteString()
         self.comment = ByteString()
+        # FASTX++ metadata block: "`" + 9 + 7 + 3 digits + "`" = 21 bytes
         self.last_char = 0
 
     fn __moveinit__(out self, owned other: Self):
@@ -384,7 +440,7 @@ struct FastxReader[R: KRead, read_comment: Bool = True](Movable):
         if c < 0:
             return Int(c)  # EOF Error
 
-        # Reset all members
+        # Reset all buffers for reuse
         self.seq.clear()
         self.qual.clear()
         self.comment.clear()
@@ -455,3 +511,276 @@ struct FastxReader[R: KRead, read_comment: Bool = True](Movable):
         if len(self.qual) != len(self.seq):
             return -2  # error: qual string is of different length
         return len(self.seq)
+
+    fn read_fastxpp_strip_newline(mut self) raises -> Int:
+        """Read one FASTX++ record, stripping LFs with `strip_newlines_in_place`.
+
+        Returns the sequence length, the negative `read_byte`/`read_until`
+        status, -3 for a malformed or truncated record, or -2 when the
+        stripped length differs from the header's slen field.
+        """
+        # ── 0 Locate the next header byte ('>' or '@') ──────────────────────
+        if self.last_char == 0:
+            var c = self.reader.read_byte()
+            while (
+                c >= 0
+                and c != ASCII_FASTA_RECORD_START
+                and c != ASCII_FASTQ_RECORD_START
+            ):
+                c = self.reader.read_byte()
+            if c < 0:
+                # EOF or stream error
+                return Int(c)
+        self.last_char = 0
+
+        # ── 1 Reset buffers reused across records ───────────────────────────
+        self.seq.clear()
+        self.qual.clear()
+        self.comment.clear()
+        self.name.clear()
+
+        # ── 2 Read the back-tick header line --------------------------------
+        var r = self.reader.read_until[SearchChar.Newline](self.name, 0, False)
+        if r < 0:
+            return Int(r)
+
+        var hdr = self.name.as_span()
+        # 21 bytes are indexed below: 1 backtick + 9 + 7 + 3 + 1 backtick
+        if len(hdr) < 21 or hdr[0] != UInt8(ord("`")):
+            return -3  # Not a proper FASTX++ header
+
+        # ── 3 Decode slen and lcnt (bpl at [17..19] is not needed here) -----
+        var slen = decode[9](hdr[1:10])  # 9-digit field at positions [1..9]
+        var lcnt = decode[7](hdr[10:17])  # 7-digit field at positions [10..16]
+
+        # Confirm the second backtick is at hdr[20]
+        if hdr[20] != UInt8(ord("`")):
+            return -3
+
+        # ── 4 Read the sequence block (slen + lcnt bytes on disk) -----------
+        var disk_seq = slen + lcnt
+        self.seq.reserve(disk_seq)
+        var _disk_seq = disk_seq
+        var got_seq = self.reader.read_bytes(self.seq, _disk_seq)
+        if got_seq != disk_seq:
+            return -3  # truncated record
+
+        # ── 5 Remove newline characters in-place using the helper -----------
+        var ok = strip_newlines_in_place(self.seq, disk_seq, slen)
+        if not ok:
+            return -2  # mismatch: not the expected base count
+
+        return len(self.seq)
+
+    fn read_fastxpp_swar(mut self) raises -> Int:
+        """Read one FASTX++ record, compacting lines by the fixed bpl stride.
+
+        Returns the sequence length, the negative `read_byte`/`read_until`
+        status, -3 for a malformed or truncated record, or -2 when the
+        header fields do not describe the bytes that were read.
+        """
+        # ── 0 Locate the next header byte ('>' or '@') ──────────────────────
+        if self.last_char == 0:
+            var c = self.reader.read_byte()
+            while (
+                c >= 0
+                and c != ASCII_FASTA_RECORD_START
+                and c != ASCII_FASTQ_RECORD_START
+            ):
+                c = self.reader.read_byte()
+            if c < 0:
+                return Int(c)  # EOF / stream error
+        self.last_char = 0
+
+        # ── 1 Reset buffers reused across records ───────────────────────────
+        self.seq.clear()
+        self.qual.clear()
+        self.comment.clear()
+        self.name.clear()
+
+        # ── 2 Read the back-tick header line --------------------------------
+        var r = self.reader.read_until[SearchChar.Newline](self.name, 0, False)
+        if r < 0:
+            return Int(r)
+
+        var hdr = self.name.as_span()
+        # 21 bytes are indexed below: 1 backtick + 9 + 7 + 3 + 1 backtick
+        if len(hdr) < 21:
+            print("ERROR: FASTX++ header is shorter than 21 bytes")
+            return -3
+        if hdr[0] != UInt8(ord("`")):
+            print("ERROR: Opening backtick check failed. hdr[0]")
+            return -3  # not a FASTX++ BPL header
+
+        # ── 3 Parse slen:lcnt:bpl and check the closing back-tick -----------
+        var slen = decode[9](hdr[1:10])  # bytes 1–9
+        var lcnt = decode[7](hdr[10:17])  # bytes 10–16
+        var bpl = decode[3](hdr[17:20])  # bytes 17–19
+
+        if hdr[20] != UInt8(ord("`")):
+            print("ERROR: Closing backtick check failed. hdr[20]")
+            return -3
+
+        # ── 4 SEQUENCE block ---------------------------------------
+        var disk_seq = slen + lcnt  # bases plus one LF per line
+        var rest_seq = disk_seq  # mutable copy for read_bytes
+
+        self.seq.reserve(disk_seq)
+        var got_seq = self.reader.read_bytes(self.seq, rest_seq)
+        if got_seq != disk_seq:
+            print("ERROR: Sequence read failed. got_seq != disk_seq")
+            return -3  # truncated record
+
+        # A stride below 2 cannot hold a base and its LF; it would also keep
+        # the loop below from advancing.
+        if slen > 0 and bpl < 2:
+            print("ERROR: bpl field is too small for a non-empty sequence")
+            return -3
+
+        # compact in-place: copy up to (bpl-1) bases, skip the LF, repeat
+        # NOTE(review): source and destination can overlap; confirm memcpy tolerates that.
+        var write_pos: Int = 0
+        var read_pos: Int = 0
+        while write_pos < slen:
+            # The last line may be shorter than the others, so never copy
+            # past the slen bases the header promised.
+            var chunk = min(bpl - 1, slen - write_pos)
+            if read_pos + chunk > disk_seq:
+                print("ERROR: header fields are inconsistent with the data")
+                return -2
+            if write_pos != read_pos:  # the first line is already in place
+                memcpy(
+                    self.seq.addr(write_pos),  # destination
+                    self.seq.addr(read_pos),  # source
+                    chunk,
+                )
+            write_pos += chunk
+            read_pos += chunk + 1  # jump over the LF
+        self.seq.resize(write_pos)  # write_pos == slen
+
+        return len(self.seq)
+
+    fn read_fastxpp_read_once(mut self) raises -> Int:
+        """Read one FASTX++ record line by line, so no LF enters `self.seq`.
+
+        Returns slen on success, the negative `read_byte`/`read_until`
+        status, or -3 for a malformed or truncated record.
+        """
+        # ── 0 Locate the next header byte ('>' or '@') ──────────────────────
+        if self.last_char == 0:
+            var c = self.reader.read_byte()
+            while (
+                c >= 0
+                and c != ASCII_FASTA_RECORD_START
+                and c != ASCII_FASTQ_RECORD_START
+            ):
+                c = self.reader.read_byte()
+            if c < 0:
+                return Int(c)
+        self.last_char = 0
+
+        # ── 1 Reset buffers reused across records ───────────────────────────
+        self.seq.clear()
+        self.qual.clear()
+        self.comment.clear()
+        self.name.clear()
+
+        # ── 2 Read the back-tick header line --------------------------------
+        var r = self.reader.read_until[SearchChar.Newline](self.name, 0, False)
+        if r < 0:
+            return Int(r)
+
+        var hdr = self.name.as_span()
+        # ── 3 Check both back-ticks and parse slen and bpl ------------------
+        # 21 bytes are indexed below: 1 backtick + 9 + 7 + 3 + 1 backtick
+        if len(hdr) < 21:
+            print("ERROR: FASTX++ header is shorter than 21 bytes")
+            return -3
+        if hdr[0] != UInt8(ord("`")):
+            print("ERROR: header lacks opening back-tick")
+            return -3
+
+        var slen = decode[9](hdr[1:10])
+        var bpl = decode[3](hdr[17:20])  # lcnt at [10..16] is not needed here
+
+        if hdr[20] != UInt8(ord("`")):
+            print("ERROR: header lacks closing back-tick")
+            return -3
+
+        # ── 4 SEQUENCE block ---------------------------------------
+        if slen > 0 and bpl < 2:
+            # `want` would never be positive, so `copied` could not grow
+            print("ERROR: bpl field is too small for a non-empty sequence")
+            return -3
+        self.seq.reserve(UInt32(slen))
+        var copied: Int = 0
+        while copied < slen:
+            var want = min(bpl - 1, slen - copied)
+
+            # track length before and after
+            var before = len(self.seq)
+            var _want = want
+            var _total = self.reader.read_bytes(self.seq, _want)
+            if _total < 0:
+                print("ERROR: read_bytes returned error", _total)
+                return -3
+            var got = Int(_total) - before  # true delta
+
+            if got != want:
+                return -3
+
+            copied += got
+
+            # consume newline
+            var nl = self.reader.read_byte()
+            if nl != ASCII_NEWLINE:
+                print("ERROR: expected newline after sequence chunk, found", nl)
+                return -3
+        return slen
+
+
+struct FileReader(KRead):
+    """`KRead` implementation that reads from a `FileHandle`."""
+    var fh: FileHandle
+
+    fn __init__(out self, owned fh: FileHandle):
+        self.fh = fh^
+
+    fn __moveinit__(out self, owned other: Self):
+        self.fh = other.fh^
+
+    fn unbuffered_read[
+        o: MutableOrigin
+    ](mut self, buffer: Span[UInt8, o]) raises -> Int:
+        """Read into `buffer` and return what `fh.read` reports."""
+        return Int(self.fh.read(buffer.unsafe_ptr(), len(buffer)))
+
+
+# ──────────────────────────────────────────────────────────────
+# Main for debugging
+# ──────────────────────────────────────────────────────────────
+#
+fn main() raises:
+    """Debug driver: print the length of every FASTX++ record in argv[1]."""
+    var argv = sys.argv()
+    if len(argv) != 2:
+        print("Usage: mojo run kseq.mojo <file.fastxpp>")
+        return
+
+    var fh = open(String(argv[1]), "r")
+    var reader = FastxReader[read_comment=False](
+        BufferedReader(FileReader(fh^))
+    )
+
+    # NOTE(review): this patch adds read_fastxpp_strip_newline/_swar/_read_once
+    # but no read_fastxpp; confirm the method exists elsewhere in FastxReader.
+    var first = reader.read_fastxpp()
+    print("first‑read returned", first)
+
+    var count = 0
+    while True:
+        var n = reader.read_fastxpp()
+        if n < 0:
+            break
+        count += 1
+        print("rec#", count, "seq_len", n, "hdr_len", len(reader.name))
diff --git a/ishlib/vendor/swar_decode.mojo b/ishlib/vendor/swar_decode.mojo
new file mode 100644
index 0000000..48fd011
--- /dev/null
+++ b/ishlib/vendor/swar_decode.mojo
@@ -0,0 +1,136 @@
+from memory import UnsafePointer
+from sys import exit
+
+alias U8x8 = SIMD[DType.uint8, 8]
+alias U32x8 = SIMD[DType.uint32, 8]
+alias ASCII_ZERO = UInt8(ord("0"))
+
+
+# 8 ASCII digits
+@always_inline
+fn decode_8(ptr: UnsafePointer[UInt8]) -> Int:
+    """Decode the 8 ASCII digits at `ptr`; the bytes are not validated."""
+    var v = ptr.load[width=8](0) - ASCII_ZERO
+    var w = v.cast[DType.uint32]()
+    var mul = U32x8(10_000_000, 1_000_000, 100_000, 10_000, 1_000, 100, 10, 1)
+    return Int((w * mul).reduce_add())
+
+
+# 6 digits (still load 8 lanes)
+@always_inline
+fn decode_6(ptr: UnsafePointer[UInt8]) -> Int:
+    """Decode the 6 ASCII digits at `ptr`; the bytes are not validated.
+
+    Eight bytes are loaded, so `ptr[6]` and `ptr[7]` must be readable; their
+    lanes are multiplied by zero and do not affect the result.
+    """
+    var v = ptr.load[width=8](0) - ASCII_ZERO
+    var w = v.cast[DType.uint32]()
+    var mul = U32x8(100_000, 10_000, 1_000, 100, 10, 1, 0, 0)
+    return Int((w * mul).reduce_add())
+
+
+# 7 digits: scalar last digit
+@always_inline
+fn decode_7(ptr: UnsafePointer[UInt8]) -> Int:
+    """Decode the 7 ASCII digits at `ptr`; `decode_6` also loads `ptr[7]`."""
+    # Read the 7th digit at ptr[6] (offset 6)
+    var last_digit = Int(ptr.offset(6).load[width=1](0) - ASCII_ZERO)
+    # decode_6 handles ptr[0] through ptr[5]
+    return decode_6(ptr) * 10 + last_digit
+
+
+# 9 digits: scalar last digit
+@always_inline
+fn decode_9(ptr: UnsafePointer[UInt8]) -> Int:
+    """Decode the 9 ASCII digits at `ptr`; the bytes are not validated."""
+    # Read the 9th digit at ptr[8] (offset 8)
+    var last_digit = Int(ptr.offset(8).load[width=1](0) - ASCII_ZERO)
+    # decode_8 handles ptr[0] through ptr[7]
+    return decode_8(ptr) * 10 + last_digit
+
+
+# 3-digit fallback
+@always_inline
+fn decode_3(ptr: UnsafePointer[UInt8]) -> Int:
+    """Decode the 3 ASCII digits at `ptr` with scalar loads."""
+    var d0 = Int(ptr.load[width=1](0) - ASCII_ZERO)
+    var d1 = Int(ptr.load[width=1](1) - ASCII_ZERO)
+    var d2 = Int(ptr.load[width=1](2) - ASCII_ZERO)
+
+    return d0 * 100 + d1 * 10 + d2
+
+
+@always_inline
+fn decode[size: UInt](bstr: Span[UInt8]) -> Int:
+    """Decode the first `size` ASCII digits of `bstr`; size is 3, 6, 7, 8 or 9."""
+    # Sizes 4 and 5 have no decoder; reject them at compile time rather than
+    # returning -1 at run time.
+    constrained[
+        size == 3 or (size >= 6 and size <= 9),
+        "size must be 3, 6, 7, 8 or 9",
+    ]()
+
+    @parameter
+    if size == 3:
+        return decode_3(bstr.unsafe_ptr())
+    elif size == 6:
+        return decode_6(bstr.unsafe_ptr())
+    elif size == 7:
+        return decode_7(bstr.unsafe_ptr())
+    elif size == 8:
+        return decode_8(bstr.unsafe_ptr())
+    elif size == 9:
+        return decode_9(bstr.unsafe_ptr())
+    else:
+        return -1
+
+
+# zero-pad an Int to a fixed width
+fn zpad(value: Int, width: Int) -> String:
+    """Return `value` in decimal, left-padded with "0" up to `width` bytes."""
+    var s = String(value)
+    var pad = width - len(s)
+    var out = String(capacity=width)
+    for _ in range(pad):
+        out.write("0")
+    out.write(s)
+    return out
+
+
+fn expect(label: String, got: Int, want: Int):
+    """Print a failure line and exit with status 1 when `got != want`."""
+    if got != want:
+        print("FAIL {", label, "} got ", got, " expected ", want)
+        exit(1)
+
+
+fn run_tests():
+    """Check each decoder; 6/7-digit inputs are padded to the 8-byte load."""
+    # 3-digit
+    expect("3-zero", decode_3("000".unsafe_ptr()), 0)
+    expect("3-123", decode_3("123".unsafe_ptr()), 123)
+    expect("3-999", decode_3("999".unsafe_ptr()), 999)
+
+    # 6-digit
+    expect("6-zero", decode_6("000000xx".unsafe_ptr()), 0)
+    expect("6-00123", decode_6("000123xx".unsafe_ptr()), 123)
+    expect("6-654321", decode_6("654321xx".unsafe_ptr()), 654321)
+
+    # 7-digit
+    expect("7-zero", decode_7("0000000x".unsafe_ptr()), 0)
+    expect("7-1234567", decode_7("1234567x".unsafe_ptr()), 1234567)
+
+    # 8-digit
+    expect("8-zero", decode_8("00000000".unsafe_ptr()), 0)
+    expect("8-87654321", decode_8("87654321".unsafe_ptr()), 87654321)
+
+    # 9-digit
+    expect("9-zero", decode_9("000000000".unsafe_ptr()), 0)
+    expect("9-123456789", decode_9("123456789".unsafe_ptr()), 123456789)
+
+    print("All SWAR decode tests passed.")
+
+
+fn main():
+    run_tests()