Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 0 additions & 72 deletions res/pygments/ptx.py

This file was deleted.

172 changes: 152 additions & 20 deletions src/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,165 @@ const Cthulhu = Base.PkgId(UUID("f68482b8-f384-11e8-15f7-abe071a5a75f"), "Cthulh
# syntax highlighting
#

# Cache for the location of the `pygmentize` executable; unassigned until first lookup.
const _pygmentize = Ref{Union{String,Nothing}}()

# Return what `Sys.which("pygmentize")` reported (a path, or `nothing` when the
# executable was not found). The lookup runs only on the first call; later calls
# return the cached value.
function pygmentize()
    isassigned(_pygmentize) || (_pygmentize[] = Sys.which("pygmentize"))
    return _pygmentize[]
end
# https://github.com/JuliaLang/julia/blob/dacd16f068fb27719b31effbe8929952ee2d5b32/stdlib/InteractiveUtils/src/codeview.jl
# Highlighting scheme: maps a token class to the `(bold, color)` pair used when the
# token is printed with `printstyled`.
const hlscheme = Dict{Symbol, Tuple{Bool, Union{Symbol, Int}}}(
    class => (bold, color) for (class, bold, color) in (
        (:default,     false, :normal),        # e.g. comma, equal sign, unknown token
        (:comment,     false, :light_black),
        (:label,       false, :light_red),
        (:instruction, true,  :light_cyan),
        (:type,        false, :cyan),
        (:number,      false, :yellow),
        (:bracket,     false, :yellow),
        (:variable,    false, :normal),        # e.g. variable, register
        (:keyword,     false, :light_magenta),
        (:funcname,    false, :light_yellow),
    )
)

# Print `code` to `io`, syntax-highlighted when `io` has color enabled.
#
# Arguments:
# - `io`:    destination stream; highlighting is only attempted when its `:color`
#            property is `true`.
# - `code`:  the source text to print.
# - `lexer`: name of the language of `code`. "llvm" is rendered with
#            `InteractiveUtils.print_llvm` and "ptx" with `highlight_ptx`; text for any
#            other lexer is printed unmodified.
#
# Returns `nothing`.
function highlight(io::IO, code, lexer)
    # `get` (not `haskey`) so that an explicit `:color => false` also yields plain text.
    if !get(io, :color, false)
        print(io, code)
    elseif lexer == "llvm"
        InteractiveUtils.print_llvm(io, code)
    elseif lexer == "ptx"
        highlight_ptx(io, code)
    else
        # no highlighter available for this lexer
        print(io, code)
    end
    return
end

# Building blocks for the regex-based PTX tokenizer. Everything below is a plain string
# holding a regex fragment; the fragments are combined into `r_line`, which matches a
# single token at the start of its input.
#
# NOTE: regex alternation takes the first alternative that matches. Wherever a fragment
# is not followed by a mandatory separator, a longer name must therefore be listed
# before any name that is a prefix of it (e.g. "f16x2" before "f16").

ptx_instructions = ["abs", "activemask", "add", "addc", "alloca", "and",
                    "applypriority", "atom", "bar", "barrier", "bfe", "bfi",
                    "bfind", "bmsk", "bra", "brev", "brkpt", "brx", "call", "clz",
                    "cnot", "copysign", "cos", "cp", "createpolicy", "cvt", "cvta",
                    "discard", "div", "dp2a", "dp4a", "ex2", "exit", "fence",
                    "fma", "fns", "isspacep", "istypep", "ld", "ldmatrix", "ldu",
                    "lg2", "lop3", "mad", "mad24", "madc", "match", "max", "mbarrier",
                    "membar", "min", "mma", "mov", "mul", "mul24", "nanosleep", "neg",
                    "not", "or", "pmevent", "popc", "prefetch", "prefetchu", "prmt",
                    "rcp", "red", "redux", "rem", "ret", "rsqrt", "sad", "selp",
                    "set", "setp", "shf", "shfl", "shl", "shr", "sin", "slct", "sqrt",
                    "st", "stackrestore", "stacksave", "sub", "subc", "suld", "suq",
                    "sured", "sust", "szext", "tanh", "testp", "tex", "tld4", "trap",
                    "txq", "vabsdiff", "vabsdiff2", "vabsdiff4", "vadd", "vadd2", "vadd4",
                    "vavrg2", "vavrg4", "vmad", "vmax", "vmax2", "vmax4", "vmin", "vmin2",
                    "vmin4", "vote", "vset", "vset2", "vset4", "vshl", "vshr", "vsub",
                    "vsub2", "vsub4", "wmma", "xor"]

# Prefixes such as "add"/"addc" are harmless here: an instruction name is always
# followed by a mandatory `\.` in `r_instruction`, so the regex engine backtracks.
r_ptx_instruction = join(ptx_instructions, "|")

# "f16x2" precedes "f16" because the type is the last element of `r_instruction`.
types = ["s8", "s16", "s32", "s64", "u8", "u16", "u32", "u64", "f16x2", "f16", "f32",
         "f64", "b8", "b16", "b32", "b64", "pred"]
r_types = join(types, "|")


# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-and-bit-size-comparisons
operators_comparison_sint = ["eq", "ne", "lt", "le", "gt", "ge"]
operators_comparison_uint = ["eq", "ne", "lo", "ls", "hi", "hs"]
operators_comparison_bit = ["eq", "ne"]

operators_comparison_float = ["eq", "ne", "lt", "le", "gt", "ge"]
operators_comparison_nanfloat = ["equ", "neu", "ltu", "leu", "gtu", "geu"]
operators_comparison_nan = ["num", "nan"]

modifiers_int = ["rni", "rzi", "rmi", "rpi"]
modifiers_float = ["rn", "rna", "rz", "rm", "rp"]
modifiers = sort(unique([modifiers_int..., modifiers_float...]))

state_spaces = ["reg", "sreg", "const", "global", "local", "param", "shared", "tex"]


operators = sort(unique([operators_comparison_sint..., operators_comparison_uint...,
                         operators_comparison_bit..., operators_comparison_float...,
                         operators_comparison_nanfloat..., operators_comparison_nan...,
                         modifiers..., state_spaces..., types...]))


r_operators = join(operators, "|")

# We can divide into types of instructions as all combinations of instructions, types and operators are not valid.
# Shape: <instruction>.[<operator>.]<type>
r_instruction = "(?:(?:$r_ptx_instruction)\\.(?:(?:$r_operators)(?:\\.))?(?:$(r_types)))"

# "local" precedes "loc" so that ".local" is not tokenized as ".loc" + "al".
directives = ["address_size", "align", "branchtargets", "callprototype",
              "calltargets", "const", "entry", "extern", "file", "func", "global",
              "local", "loc", "maxnctapersm", "maxnreg", "maxntid",
              "minnctapersm", "param", "pragma", "reg", "reqntid", "section",
              "shared", "sreg", "target", "tex", "version", "visible", "weak"]

# A directive is a literal dot followed by the directive name.
r_directive = "(?:\\.(?:" * join(directives, "|") * "))"


# Numeric literals
r_hexdigit = "[0-9a-fA-F]"
r_hex = "0[xX]$(r_hexdigit)+U?"
r_octal = "0[0-7]+U?"
r_binary = "0[bB][01]+U?"
r_decimal = "[0-9]+U?"
r_float = "0[fF]$(r_hexdigit){8}"      # 0f followed by 8 hex digits
r_double = "0[dD]$(r_hexdigit){16}"    # 0d followed by 16 hex digits

# The prefixed forms come first: otherwise `r_decimal` would match the leading "0" of
# e.g. "0f3F800000" and split the literal.
r_number = join(map(x -> "(?:" * x * ")",
                    [r_hex, r_binary, r_float, r_double, r_octal, r_decimal]), "|")

# Special registers, followed by a generic fallback for virtual registers (%r1, %rd12,
# ...). Longer names are listed before their prefixes (see the note at the top).
r_register_special = ["%clock64", "%clock_hi", "%clock", "%ctaid", "%dynamic_smem_size",
                      "%envreg\\d{0,2}", # envreg0-31
                      "%globaltimer_hi", "%globaltimer_lo", "%globaltimer", "%gridid",
                      "%laneid", "%lanemask_eq", "%lanemask_ge", "%lanemask_gt",
                      "%lanemask_le", "%lanemask_lt", "%nctaid", "%nsmid", "%ntid",
                      "%nwarpid", "%pm\\d_64", "%pm\\d",
                      "%reserved_smem_offset_[01]", # documented as %reserved_smem_offset_<2>
                      "%reserved_smem_offset_begin", "%reserved_smem_offset_cap",
                      "%reserved_smem_offset_end", "%smid", "%tid", "%total_smem_size",
                      "%warpid", "%\\w{1,2}\\d*"]

# NOTE: not wrapped in a group; always embed as `(?:$r_register)`.
r_register = join(r_register_special, "|")


# Identifier syntax from the PTX documentation:
#   followsym:  [a-zA-Z0-9_$]
#   identifier: [a-zA-Z]{followsym}* | {[_$%]{followsym}+
r_followsym = "[a-zA-Z0-9_\$]"
r_identifier = "[a-zA-Z]$(r_followsym)*|[_\$%]$(r_followsym)+"

r_guard_predicate = "@!?%p\\d*"
r_label = "[\\w_]+:"
r_comment = "//"
r_unknown = "[^\\s]*"

# One token. The alternatives are tried in this order; `r_unknown` is the catch-all.
r_line = "(?:(?:$r_directive)|(?:$r_instruction)|(?:$r_register)|(?:$r_number)|(?:$r_label)|(?:$r_guard_predicate)|(?:$r_comment)|(?:$r_identifier)|(?:$r_unknown))"

# Method for a `nothing` argument: returns a triple of `nothing`s.
function get_token(n::Nothing)
    return (nothing, nothing, nothing)
end

# simple regex-based highlighter
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
#
# Writes `code` to `io` one line at a time. Each line is split into tokens with
# `r_line`; every token is printed in the `hlscheme` style of the first token class it
# matches, preceded by the whitespace that came before it. Once a comment token is
# seen, the remainder of the line is printed in the comment style. A newline is written
# after every line.
#
# Arguments:
# - `io`:   destination stream; styling is applied through `printstyled`.
# - `code`: the PTX source text.
function highlight_ptx(io::IO, code::AbstractString)
    # Build the patterns once per call instead of once per token.
    token_re = Regex("^(\\s*)($r_line)")
    anchored(pattern) = Regex("^(" * pattern * ")")
    comment_re = anchored(r_comment)
    # Tried in this order; the first pattern that matches decides the style.
    styles = (anchored(r_label)           => :label,
              anchored(r_instruction)     => :instruction,
              anchored(r_directive)       => :type,
              anchored(r_guard_predicate) => :keyword,
              anchored(r_register)        => :number)

    function token_style(tok)
        for (re, style) in styles
            occursin(re, tok) && return style
        end
        return :default
    end

    function print_tok(token, type)
        bold, color = hlscheme[type]
        return Base.printstyled(io, token; bold = bold, color = color)
    end

    for line in eachline(IOBuffer(code))
        rest = SubString(line)
        while true
            m = match(token_re, rest)
            if m === nothing
                # Not expected (`r_line` has a catch-all), but never drop text.
                print(io, rest)
                break
            end
            indent, tok = m.captures[1], m.captures[2]
            # Whatever follows the token is kept as-is, so text is never lost, even
            # when the token is directly followed by a word character.
            rest = SubString(rest, 1 + ncodeunits(m.match))
            print(io, indent)
            if occursin(comment_re, tok)
                print_tok(tok, :comment)
                isempty(rest) || print_tok(rest, :comment)
                break
            end
            print_tok(tok, token_style(tok))
            isempty(rest) && break
            if isempty(m.match)
                # Defensive: a zero-width match would loop forever; emit the rest verbatim.
                print(io, rest)
                break
            end
        end
        print(io, '\n')
    end
end

#
# code_* replacements
Expand Down